diff --git a/.gitmodules b/.gitmodules
index d7f1a58bf9acd0c48d5a174b9a2eb55d068adfb2..630973e7f269884408ddf2d3b159fe924fe1c67a 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,5 +1,4 @@
 [submodule "mindspore"]
 path = mindspore
 url = https://gitee.com/mindspore/mindspore.git
-# shallow = true
 branch = r2.7.rc1
\ No newline at end of file
diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 4f562fc726bfc74fd5f6a3364bb722f3e1285bd2..10e853cefed57b9ad69aa6ea88c32f6d907dbb64 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -36,3 +36,7 @@
 "mindspore-lite/mindspore-lite/examples/quick_start_micro/" "syntaxError"
 "mindspore-lite/mindspore-lite/python/src/pybind_module.cc" "syntaxError"
 "mindspore-lite/mindspore-lite/java/src/main/native/model.cpp" "unreadVariable"
+
+# nnacl
+"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/" "unreadVariable"
+"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c" "unknownMacro"
diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt
index 7b1dc36a0461f6e4f1dfb282d51ab76c6fb59526..60814c367e5ffbf35cf102d079e52616e4ba0385 100644
--- a/.jenkins/check/config/filter_cpplint.txt
+++ b/.jenkins/check/config/filter_cpplint.txt
@@ -91,3 +91,6 @@
 "mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "legal/copyright"
 "mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "whitespace/ending_newline"
 "mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include"
+
+# nnacl
+"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/" "readability/casting"
diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index c76e4a3d5c9a3bec41412bd3c4ff06807218a78f..b494fdf54878f0b8896325e4be0953a3ce032489 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -36,3 +36,208 @@ mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_vid
 # other
 mindspore-lite/mindspore-lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
 mindspore-lite/mindspore-lite/providers/nnie/src/custom_infer.cc:mindspore::nnie::CustomInterface::Infer
+
+# nnacl
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c:StridedSliceInferShape
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c:CheckInputShapeValid
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c:WinogradInputTransformFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:AvgPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:MaxPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2ReluUnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6ReluUnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c:AvgPoolingOptInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c:MaxPoolingWithQuantInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c:Conv1x1PreOptPeroc
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c:RegisterInfer
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c:RowMajor2Col12MajorStride
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c:RowMajor2Col8MajorStride
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c:Conv3x3Fp16InputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c:Conv3x3Fp16FilterTransform
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:AvgPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:MaxPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c:PackNHWCToNCHWFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:InputTransform6x6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:InputTransform8x8UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6ReluUnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c:AvgPoolingOptInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8InputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8FilterTransform
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c:Conv1x1PreOptPeroc
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c:Conv1x1PreOptPert
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c:PackNHWCToNCHWInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c:AvgPooling
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c:MatMul4x1Kernel, MatMul2x1Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c:SWConv3x32AVXKernel, SWConv4x24AVXKernel, SWConv12x8AVXKernel, SWConv8x8AVXKernel, SWConv4x8AVXKernel, SWConv6x16AVXKernel, SWConv4x16AVXKernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c:Conv1x1SW3x32AVXKernel, Conv1x1SW4x24AVXKernel, Conv1x1SW12x8AVXKernel, Conv1x1SW8x8AVXKernel, Conv1x1SW4x8AVXKernel, Conv1x1SW6x16AVXKernel, Conv1x1SW4x16AVXKernel, Conv1x1SW1x32AVXKernel, Conv1x1SW1x24AVXKernel, Conv1x1SW1x16AVXKernel, Conv1x1SW1x8AVXKernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC4.c:PostFuncBiasReluC4
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c:WinogradTransRight
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c:WinogradTransLeft
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c:WinogradPostFuncBiasReluC4
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c:PostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c:WinogradPostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c:WinogradPostFuncBiasReluC4
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c:PostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c:WinogradTransLeft
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c:WinogradTransRight
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c:PostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c:WinogradPostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c:DeConvWgMerge
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c:MatmulFloatSse64Opt
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c:ConvWinogardFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c:ConvWinogardFp32CutByBatch
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c:GemmRowxColMaskKernelFp32
\ No newline at end of file
diff --git a/mindspore-lite/CMakeLists.txt b/mindspore-lite/CMakeLists.txt
index c497ee12a09a0839cfdf1cdab3d2a72e29f94e52..8acc712e352b5b686064cba5d40ff6443053e163 100644
--- a/mindspore-lite/CMakeLists.txt
+++ b/mindspore-lite/CMakeLists.txt
@@ -743,12 +743,12 @@
 set(CORE_DIR ${TOP_DIR}/mindspore/mindspore/core)
 set(CORE_INC_DIR ${TOP_DIR}/mindspore/mindspore/core/include)
 set(CCSRC_DIR ${TOP_DIR}/mindspore/mindspore/ccsrc)
 set(OPS_DIR ${TOP_DIR}/mindspore/mindspore/ops)
-set(NNACL_DIR ${OPS_DIR}/kernel/cpu/nnacl)
+set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/litert/kernel/cpu/nnacl_c)
 if(PLATFORM_MCU)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-incompatible-pointer-types")
     # set(MSLITE_DEPS_CMSIS on)
-    add_subdirectory(${NNACL_DIR} build/nnacl)
+    add_subdirectory(${NNACL_DIR} build/nnacl_c)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/cmake/cortex-m/ build)
     include(${TOP_DIR}/cmake/package_lite.cmake)
     return()
@@ -1062,7 +1062,7 @@ if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX)
 endif()

 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
-add_subdirectory(${OPS_DIR}/kernel/cpu/nnacl build)
+add_subdirectory(${NNACL_DIR} build)

 if(MSLITE_ENABLE_TOOLS)
     if(NOT MSLITE_COMPILE_TWICE)
diff --git a/mindspore-lite/java/native/CMakeLists.txt b/mindspore-lite/java/native/CMakeLists.txt
index ef9368c1c9c01c9294bf93567904d7f33e83ae9a..3cac36427560eb545b820afc35c97d7f543cddd7 100644
--- a/mindspore-lite/java/native/CMakeLists.txt
+++ b/mindspore-lite/java/native/CMakeLists.txt
@@ -10,6 +10,7 @@
 set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
 set(MINDSPORE_DIR ${TOP_DIR}/mindspore)
 set(LITE_DIR ${TOP_DIR}/mindspore-lite)
 set(NEW_NATIVE_DIR ${LITE_DIR}/java/src/main/native)
+set(NNACL_DIR ${LITE_DIR}/src/litert/kernel/cpu/nnacl_c)
 include(${LITE_DIR}/cmake/secure_option.cmake)
 include(${LITE_DIR}/cmake/compile_link_option.cmake)
@@ -110,7 +111,7 @@
 include_directories(${MINDSPORE_DIR}) ## api include
 include_directories(${MINDSPORE_DIR}/mindspore/core/include) ## core include
 include_directories(${MINDSPORE_DIR}/mindspore/core/mindrt) ## core include
 include_directories(${MINDSPORE_DIR}/mindspore/core/mindrt/include) ## core include
-include_directories(${MINDSPORE_DIR}/mindspore/ops/kernel/cpu)
+include_directories(${NNACL_DIR}/../)
 include_directories(${TOP_DIR}/build) ## flatbuffers
 if(PLATFORM_ARM64 OR PLATFORM_ARM32)
@@ -137,7 +138,7 @@ set(JNI_SRC
 )

 set(CCSRC
-        ${MINDSPORE_DIR}/mindspore/ops/kernel/cpu/nnacl/nnacl_common.c
+        ${NNACL_DIR}/nnacl_common.c
 )

 if(MSLITE_ENABLE_PARALLEL_INFERENCE)
diff --git a/mindspore-lite/java/native/common/jni_utils.h b/mindspore-lite/java/native/common/jni_utils.h
index 3980a4d0c4565bf13a6b2c431839c8f17658ac4e..1f57a43a9ba0a5f961342054ecd1c44f63d56323 100644
--- a/mindspore-lite/java/native/common/jni_utils.h
+++ b/mindspore-lite/java/native/common/jni_utils.h
@@ -19,7 +19,7 @@
 #include
 #include
 #include
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"

 std::string RealPath(const char *path);
diff --git a/mindspore-lite/minddata/CMakeLists.txt b/mindspore-lite/minddata/CMakeLists.txt
index 4557feccf7fbb9b8c0b2685a85dbac821f319ec9..a7b68b5117cd67dc9fcccfc46f37e5efc86f217d 100644
--- a/mindspore-lite/minddata/CMakeLists.txt
+++ b/mindspore-lite/minddata/CMakeLists.txt
@@ -94,7 +94,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
     include_directories("dataset/liteapi")
     include_directories("${TOP_DIR}/mindspore-lite")
     include_directories("${TOP_DIR}")
-    include_directories("${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu")
+    include_directories("${NNACL_DIR}/../")

     if(MSLITE_ENABLE_ACL)
         include_directories(${CCSRC_DIR})
@@ -105,7 +105,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
         ${TOP_DIR}/mindspore-lite/src/litert/cxx_api/tensor_utils.cc
         ${TOP_DIR}/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.cc
         ${TOP_DIR}/mindspore-lite/src/tensor.cc
-        ${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c
+        ${NNACL_DIR}/tensor_c_utils.c
         ${TOP_DIR}/mindspore-lite/src/common/utils.cc
         ${TOP_DIR}/mindspore-lite/src/common/string_util.cc)
diff --git a/mindspore-lite/python/CMakeLists.txt b/mindspore-lite/python/CMakeLists.txt
index 14747406db2807d228fc9ec2e9086994b18abb73..238d0950ade20387fb9036cb3672ea56f809feea 100644
--- a/mindspore-lite/python/CMakeLists.txt
+++ b/mindspore-lite/python/CMakeLists.txt
@@ -17,7 +17,7 @@ if(Python3_FOUND)
     include_directories(${TOP_DIR}/mindspore/mindspore/core/include)
    include_directories(${TOP_DIR}/mindspore/mindspore/core/mindrt)
     include_directories(${TOP_DIR}/mindspore/mindspore/core/mindrt/include)
-    include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/)
+    include_directories(${NNACL_DIR}/../)

     if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
         add_compile_definitions(MSLITE_ENABLE_CLOUD_INFERENCE)
diff --git a/mindspore-lite/src/CMakeLists.txt b/mindspore-lite/src/CMakeLists.txt
index 84c54640be23f65f77a815284bc3ec6c4758b639..7e6c3718fdfe6ee1a7a59f63e5be05ede9ab85a0 100644
--- a/mindspore-lite/src/CMakeLists.txt
+++ b/mindspore-lite/src/CMakeLists.txt
@@ -2,7 +2,7 @@
 add_compile_definitions(USE_ANDROID_LOG)
 set(LITE_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CORE_INC_DIR})
-include_directories(${OPS_DIR}/kernel/cpu)
+include_directories(${NNACL_DIR}/../)
 include_directories(${OPS_DIR}/kernel/include)

 set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../tools)
diff --git a/mindspore-lite/src/common/common.h b/mindspore-lite/src/common/common.h
index 0de6cb8f5a2dfc0a3b8809dd243ba3308c21559d..83ea8e7abacd9afcc6cdc2394e6b098268d8f52d 100644
--- a/mindspore-lite/src/common/common.h
+++ b/mindspore-lite/src/common/common.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_COMMON_COMMON_H_

 #include
-#include "mindspore/ops/kernel/cpu/nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 /* Naming a key of path must be consistent with existing naming styles and follow the following rules:
diff --git a/mindspore-lite/src/common/graph_util.cc b/mindspore-lite/src/common/graph_util.cc
index d5cb2a985d0dcc6b06ae002df9c6850231adac75..c025391311395f1aee9f27807c1df07c13613fef 100644
--- a/mindspore-lite/src/common/graph_util.cc
+++ b/mindspore-lite/src/common/graph_util.cc
@@ -23,7 +23,7 @@
 #include "src/common/log_adapter.h"
 #include "src/common/version_manager.h"
 #include "include/errorcode.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/CMakeLists.txt b/mindspore-lite/src/common/ops/CMakeLists.txt
index 0c6a666fac6a5604107e82cb4668b7ca1c338925..a6cfa820270b9351c5cc21298a6c4711685384a3 100644
--- a/mindspore-lite/src/common/ops/CMakeLists.txt
+++ b/mindspore-lite/src/common/ops/CMakeLists.txt
@@ -1,5 +1,5 @@
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
-include_directories(${OPS_DIR}/kernel/cpu)
+include_directories(${NNACL_DIR}/../)
 if(APPLE)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstrict-aliasing -ffunction-sections \
     -fdata-sections -ffast-math -fno-rtti -fno-exceptions -Wno-shorten-64-to-32 \
diff --git a/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc
index 0299e0a27db9749f4e5904233ff17c5324e4f232..e4681cbaa6217038c6d0244155e838c780400202 100644
--- a/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
 #include "infer/grad/activation_grad.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameActivationGrad;
diff --git a/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc b/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc
index e2294c51a4524cfce27d688013bb22dda119bf0e..587a77537eaaca52342a7f07284be74e88d7ac01 100644
--- a/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/activation.h"
 #include "infer/leaky_relu.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc b/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc
index ec1936f7982540451228d101110a986cdeff0582..12c525e5f6a23c055e024986b44377778711969b 100644
--- a/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "infer/adder.h"
 #include "infer/cxx_api/adder_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc b/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc
index 4056c571424ce497266474d8242b5803d448ec82..45ab33ed3f19d233053da4665f932fbe77b485ab 100644
--- a/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/affine_parameter.h"
+#include "nnacl_c/affine_parameter.h"
 #include "infer/affine.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAffine;
diff --git a/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc b/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc
index 149d598340ca995bfcfd0752ac64920e225fb22e..fd7c3450742022f64983307c6928d88159fbc02e 100644
--- a/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/all_gather_parameter.h"
+#include "nnacl_c/all_gather_parameter.h"
 #include "infer/all_gather.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAllGather;
diff --git a/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc b/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc
index acbb5363242995456abd3d58ebbad6a953fb0a4d..78a503ac76e9bfd3c6e089ccce2d4fa3ceb2ee63 100644
--- a/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc
@@ -16,7 +16,7 @@
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
 #include "mindspore/ops/op_def/array_ops.h"
-#include "nnacl/arg_min_max_parameter.h"
+#include "nnacl_c/arg_min_max_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/arg_max_fusion.h"
 #include "infer/cxx_api/arg_min_fusion.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h b/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h
index 9f514dae3cbd1af685940a2fe523dde49a6f81eb..c14e2f94e1c0575c656d668c9346a5621d7fabcd 100644
--- a/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h
+++ b/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h
@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_ARITHMETIC_POPULATE_H_
 #define MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_ARITHMETIC_POPULATE_H_

-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc b/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc
index 35a4a21961e4f573015ee44a6662ce57d0f60577..1c922365c560f474236e44795764a84bebac8393 100644
--- a/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/arithmetic_self_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/grad/log_grad.h"
 #include "infer/grad/neg_grad.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc b/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc
index 128cb9c5fea3cc33cf25f4308e18ecf5b3b95313..9b9470cfbe92d9c25163de73c39588caf7ceb36c 100644
--- a/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/attention_parameter.h"
+#include "nnacl_c/attention_parameter.h"
 #include "infer/attention.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAttention;
diff --git a/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc b/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc
index a9930018ce2e0a0d5b5429a29d6fba48b9e377d8..2aa371f25a3f827246fb161a8854a80b7b9fe276 100644
--- a/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/infer/audio_spectrogram_infer.h"
+#include "nnacl_c/infer/audio_spectrogram_infer.h"
 #include "infer/audio_spectrogram.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAudioSpectrogram;
diff --git a/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc
index a42ac7f53fea4794d259aa376b8d386b194c0a17..406854a354622783fe1f0bdbe35e04ef28bfa553 100644
--- a/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc
@@ -14,16 +14,16 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/where_parameter.h"
-#include "nnacl/sparse_to_dense_parameter.h"
-#include "nnacl/transpose_parameter.h"
-#include "nnacl/triu_tril_parameter.h"
-#include "nnacl/fp32/unique_fp32.h"
-#include "nnacl/scatter_nd_parameter.h"
-#include "nnacl/op_base.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/gather_nd_parameter.h"
-#include "nnacl/reshape_parameter.h"
+#include "nnacl_c/where_parameter.h"
+#include "nnacl_c/sparse_to_dense_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
+#include "nnacl_c/triu_tril_parameter.h"
+#include "nnacl_c/fp32/unique_fp32.h"
+#include "nnacl_c/scatter_nd_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/gather_nd_parameter.h"
+#include "nnacl_c/reshape_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/adam.h"
 #include "infer/assert.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc
index 7b0248fd4482db78be3376ddbb591628acc427f7..a89aea9f34f548006ef3fe80a530648eece02a47 100644
--- a/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
 using mindspore::schema::PrimitiveType_BatchNorm;
diff --git a/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc b/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc
index e5385ee51151a94bedfacc13d1db7195545cd38b..ab8e965bd9e45f4e097346de5ef233becaba648b 100644
--- a/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/batch_to_space_parameter.h"
+#include "nnacl_c/batch_to_space_parameter.h"
 #include "infer/batch_to_space.h"
 #include "infer/batch_to_space_nd.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc b/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc
index 8afe3ef49a29f5b5af3b0c9be1704cadebeb5b9d..d4a97f54c3e2acbb64377dc3a696443d30695945 100644
--- a/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/base/broadcast_to.h"
+#include "nnacl_c/base/broadcast_to.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/call_populate.cc b/mindspore-lite/src/common/ops/operator_populate/call_populate.cc
index 785c7c55516d7f132f421dab33bf2ec36608e182..a727ba0205e625d546d073a977aab76f9c70c397 100644
--- a/mindspore-lite/src/common/ops/operator_populate/call_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/call_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/call_parameter.h"
+#include "nnacl_c/call_parameter.h"
 #include "infer/call.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCall;
diff --git a/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc b/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc
index 22eee0fda340b47cbca953ee07e8e70fe44e1a8b..491416fdbbd396d6634c5de04a73ad3c1f28e5ff 100644
--- a/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/clip_parameter.h"
+#include "nnacl_c/clip_parameter.h"
 #include "infer/clip.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameClip;
diff --git a/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc b/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc
index 76ff105ad60191b403ea2ce8efdb20428effef81..e10d7e73bf3c42e46757dd305e9ba5f4549b5cca 100644
--- a/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameConcat;
diff --git a/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
index 499ea4f83ae1ba31559936fab078c23f40a2e92f..1fefda841324608c5b6f6d2d82755b94bdeaa014 100644
--- a/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/constant_of_shape_parameter.h"
+#include "nnacl_c/constant_of_shape_parameter.h"
 #include "infer/constant_of_shape.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameConstantOfShape;
diff --git a/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc b/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc
index d2762e8786c484262b2b77e904a9c439bfad5a76..f326917f1401f9fde4f288fc3dec400c4ce3d7e7 100644
--- a/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "infer/conv2d.h"
 #include "infer/cxx_api/conv2d_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc b/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc
index 801c8c2bb044d86af26a9693583bdedb2b012888..99d4dff4484b68f9baa592112e48987c56890cbf 100644
--- a/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" #include "infer/crop_and_resize.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::ops::kNameCropAndResize; diff --git a/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc b/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc index ad24e8e3ed5597648f75e86cd21dd7829a7fc1a2..d52c98d9ab5a4bbbf591b1ccfb49717fea4ec9ee 100644 --- a/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/crop_parameter.h" +#include "nnacl_c/crop_parameter.h" #include "infer/crop.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::ops::kNameCrop; diff --git a/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc b/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc index 83ceca366f529ef7fc793f5ec5ee4b6a72cc9abb..3bbb8c45054042175c59c076fe522a117b59bc49 100644 --- a/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/cumsum_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::ops::kNameCumSum; diff --git a/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc b/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc index 54f212858ab7509bfaeb6bdb6ffed1a3cd3d46c8..80e1875000373744c3e0bdcbc2c7d7d3ce4dc4a3 100644 --- a/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/custom_parameter.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/custom_parameter.h" +#include "nnacl_c/split_parameter.h" #include "infer/custom.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::ops::kNameCustom; diff --git a/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc b/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc index b17e6075c2c2e6e61006a1b07f7b7d62525ee83d..46207a60793c275c7262ef8414a538597c0290c6 100644 --- a/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/predict_parameter.h" +#include "nnacl_c/predict_parameter.h" #include "infer/custom_predict.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::ops::kNameCustomPredict; diff --git a/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc b/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc index 26a660123096467e9753a8058e9d402afb7a5aa8..494f7e68e3ef7d928f08e390b7b3bec42665d669 100644 --- a/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::ops::kNameConv2dTransposeFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc b/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc index 155038b02fa0e289f1d8a1239528694c0e17cb52..2893631524e830b0b30e8b935dba7d8651a7c5cb 100644 --- a/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/depth_to_space_parameter.h" +#include "nnacl_c/depth_to_space_parameter.h" #include "infer/depth_to_space.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" using mindspore::ops::kNameDepthToSpace; diff --git a/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc b/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc index 80d8e9fc432985094775781c78c2f746a0ee36b3..7a10ce920aeb71b635157673e976588d48c7869b 100644 --- a/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/detection_post_process_parameter.h" +#include "nnacl_c/detection_post_process_parameter.h" #include "infer/detection_post_process.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" using mindspore::ops::kNameDetectionPostProcess; diff --git a/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc b/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc index 6da61ae05cdf4f0d7e17212de21ce25ef4f3497d..cbd34195ef218048059687b4b6c1244255a61a35 100644 --- a/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/dynamic_quant_parameter.h" +#include "nnacl_c/dynamic_quant_parameter.h" #include "infer/dynamic_quant.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" using mindspore::ops::kNameDynamicQuant; diff --git a/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc b/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc index cd37e26310a352fb91c3cd7e99af5eee94e1a58f..42e93a36577de1786eb221367c49810ee1e17812 100644 --- a/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/embedding_lookup_fp32.h" +#include "nnacl_c/fp32/embedding_lookup_fp32.h" #include "infer/embedding_lookup.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" using mindspore::ops::kMaxNorm; diff --git a/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc b/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc index 03b6176639192f84efcd424018faf9540ac82858..3c414b502a7638aebc745a1e46a756c23dd3df3e 100644 --- a/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/fp32/exp_fp32.h" +#include "nnacl_c/fp32/exp_fp32.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/cxx_api/exp_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc b/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc index cf2cbc8bc413921438a37fa56329f239d38e229e..e48de873fd4c90db97baad605231d12ba304380f 100644 --- a/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/flatten_parameter.h" +#include "nnacl_c/flatten_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" using mindspore::ops::kNameFlatten; diff --git a/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc b/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc index 279bf12fa88acd65f87eefa92959a1f19b2a0394..f850b7dd6c091086877ba01c2573d6c5f9acb14c 100644 --- a/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "infer/cxx_api/full_connection.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" using mindspore::ops::kNameFullConnection; diff --git a/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc index bb27d5c81f7a464b31f67473994055f87b4ff8fe..ba21e6cf372d6213638b797c380412c82e019051 100644 --- a/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" #include "infer/fused_batch_norm.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" using mindspore::ops::kNameFusedBatchNorm; diff --git a/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc b/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc index 9d565e524d20be910d663e655892ae5cadfc700d..de7c0822346e72a7cc3df863eaa2d045e030b5c0 100644 --- a/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/glu_parameter.h" +#include "nnacl_c/glu_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" using mindspore::ops::kAxis; using mindspore::ops::kNameGLU; diff --git a/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc index f303eee700ef484399c45fed977e72db41dfaee5..ad9110b4b5492a9cf72c91ba9e4773152b333a30 100644 --- a/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/group_norm_parameter.h" +#include "nnacl_c/group_norm_parameter.h" #include "infer/cxx_api/groupnorm_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" using mindspore::ops::kNameGroupNormFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc b/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc index 014d1b6326ac716b73c80dae72974f9cf5851fee..0c1b0a6b37230d043eaf30bae7e41c9826fe20b6 100644 --- a/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/gru_fp32.h" +#include "nnacl_c/fp32/gru_fp32.h" #include "infer/gru.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" using mindspore::ops::kBidirectional; diff --git a/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc index c83da511c233cabe68a083789ab751fd797b06bb..784981686c503a37a62b08b6fc8a027c83f497fb 100644 --- a/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/instance_norm_parameter.h" +#include "nnacl_c/instance_norm_parameter.h" #include "infer/instance_norm.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" using mindspore::ops::kEpsilon; diff --git a/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc index 8243dedc35aa38d1ab0dee451ff71228c95ce52e..219ddf16b8070c2b19c51d231e76221d82e7ebed 100644 --- a/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/l2_norm_parameter.h" +#include "nnacl_c/l2_norm_parameter.h" #include "infer/l2_normalize.h" #include "infer/cxx_api/l2_normalize_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc index a3b2b32a5d08e60a79062cdd28debca59ba1f67c..cf1068254673b81c9f1b4553fb99bd05d24d63bc 100644 --- a/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32_grad/layernormgrad_parameter.h" +#include "nnacl_c/fp32_grad/layernormgrad_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" using mindspore::ops::kNameLayerNormGrad; diff --git a/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc index 8166bc66f302daa665e2deae67f860e07565713d..0034b530ca2d93df2deae00a79aa89baac89fee2 100644 --- a/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/cxx_api/layer_norm_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc b/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc index d004ca578b60113ca8774472e2e5c7d239ba2f5e..8c806e6db87b64f1c32d34e49094e6b3686f49b1 100644 --- a/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/local_response_norm_fp32.h" +#include "nnacl_c/fp32/local_response_norm_fp32.h" #include "infer/lrn.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" using mindspore::ops::kNameLRN; diff --git a/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc b/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc index 10e5701a289f4a9c75273458a3e17148195ee549..c4e5860912a2a4f321321194e823a0b30e592d62 100644 --- a/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc b/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc index 6cd51b7fa3c3fa56c123af45bbf111cd71ac392f..a9eb9942e590832a24845671f5eb2ec9d5570fa6 100644 --- a/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/lsh_projection_parameter.h" +#include "nnacl_c/lsh_projection_parameter.h" #include "infer/lsh_projection.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" using mindspore::ops::kNameLshProjection; diff --git a/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc index 9ee910b860586bc4fd41643b4863aaf8c45135fb..d9faa31a74e08472e4dffc9e77b35a97b21cf5dd 100644 --- a/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" #include "infer/lstm.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" using mindspore::ops::kNameLSTM; diff --git a/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc b/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc index 939c686fe89385a63f949729390de5ab8ccb6f99..ed17cbd18bea8f63eeaf9aa38cc82544f4dd8b30 100644 --- a/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/cxx_api/mat_mul_fusion.h" #include "mindspore/ops/op_def/op_name.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc b/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc index eaf48b97fc9d38d2706bfc9845af5901e94f7fd0..613e9b593d747f8ebe5a17541ec6eb5d5c3c8203 100644 --- a/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/infer/mfcc_infer.h" +#include "nnacl_c/infer/mfcc_infer.h" #include "infer/mfcc.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" using mindspore::ops::kNameMfcc; diff --git a/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc b/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc index cf04225c94a16e041e84af4aaaab0c6d03c19d2c..2a5fbdd71139f5b496157ab095e8adce4feb6a88 100644 --- a/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc @@ -16,7 +16,7 @@ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/nllloss_parameter.h" +#include "nnacl_c/nllloss_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h" using mindspore::ops::kNameNLLLoss; diff --git a/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc b/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc index 90b0907b08fd1a630ab2061afd5fe24c6a58fe9e..ace5ba69fc87b8863ee8f01312378b514f719cd8 100644 --- a/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl_c/non_max_suppression_parameter.h" #include "infer/non_max_suppression.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h" using mindspore::ops::kNameNonMaxSuppression; diff --git a/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc b/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc index 55761e8cf49e05e86397e97f7cf406b6cc63e104..77a5f2ea5a8f2c40291cc88961600dd01ee18e9e 100644 --- a/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" using mindspore::ops::kNameOneHot; diff --git a/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h b/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h index 204f486cbee4e49d71886e02c0d0dc391ba8fffe..50718f0ffcfa8bd913493341b9f1d6d724e40e06 100644 --- a/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h +++ b/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "src/common/log_adapter.h" #include "src/common/version_manager.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc b/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc index 35bf274601ec1f964c2eca430ca91ae80f280574..225ffc5a451e95ba3c988a208756bfea1a8d7543 100644 --- a/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/prelu_parameter.h" +#include "nnacl_c/prelu_parameter.h" #include "infer/cxx_api/prelu_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" using mindspore::ops::kNamePReLUFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc index e2a90db1f9d7626aa04e500bd5053840aaa66e44..d694075996d4c808b26d175f62d6c79406acc6ea 100644 --- a/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" #include "infer/cxx_api/pad_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" using mindspore::ops::kNamePadFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc b/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc index 96c9855fe77654767c1ff56c0ec3be7cb6cd2133..efc57b27fe762d5150059ac38aab03cbccb6bfb0 100644 --- a/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/partial_fusion_parameter.h" +#include "nnacl_c/partial_fusion_parameter.h" #include "infer/cxx_api/partial_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" using mindspore::ops::kNamePartialFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc b/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc index 0339695c59218b39ed2e36107bc4e5dc8aacf94a..2f8358ada00a56a265537c9156859ae2c4cab55f 100644 --- a/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/power_populate.cc b/mindspore-lite/src/common/ops/operator_populate/power_populate.cc index 08001d52330b6413e953bc1ffa5f89329e1b0665..7960d5d923f56f50eaaeea510810c6aa376ba808 100644 --- a/mindspore-lite/src/common/ops/operator_populate/power_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/power_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/pow_parameter.h" +#include "nnacl_c/pow_parameter.h" #include "infer/cxx_api/pow_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" using mindspore::ops::kNamePowFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc b/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc index e58e844899e2060db747f94e8cc7a8d00f4cc114..93ae68de4216c4d19b3d0ce0253f9dd41bedcbbf 100644 --- a/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/prior_box_parameter.h" +#include "nnacl_c/prior_box_parameter.h" #include "infer/prior_box.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" using mindspore::ops::kNamePriorBox; diff --git a/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc b/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc index acbb4c1c8865a3002e64ea35a61b1c24d04e2fa4..40c3054de3d6dcf45f9873cfe4dfbdc09c442086 100644 --- a/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "infer/quant_dtype_cast.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" using mindspore::ops::kNameQuantDTypeCast; diff --git a/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc b/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc index a293805f2df670479c8655fdb4c9015b8a1f8313..bbf19eb54f926df779e4d9bf734ce09a92017199 100644 --- a/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" #include "infer/random_normal.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameRandomNormal; diff --git a/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc b/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc index 7f4a060d4e9b9480305b1b043a9326e389feaa5c..e2155b909bb7151f2eb57d4fc22b1ad7d3dc4dae 100644 --- a/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" #include "infer/random_standard_normal.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameRandomStandardNormal; diff --git a/mindspore-lite/src/common/ops/operator_populate/range_populate.cc b/mindspore-lite/src/common/ops/operator_populate/range_populate.cc index 
b2a65a814a54cf480bd5b129c08dfff094c6c8a7..4446c9e8c676648015639246aca42596015d3e9e 100644 --- a/mindspore-lite/src/common/ops/operator_populate/range_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/range_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/range_parameter.h" +#include "nnacl_c/range_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameRange; diff --git a/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc b/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc index 6742ef2cdd27fd16b45b7c9d5da4972921ff4f46..73b738a45a3160a350f81de9929eb29609167ef8 100644 --- a/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" #include "infer/reduce.h" #include "infer/cxx_api/reduce_fusion.h" #include "mindspore/ops/op_def/op_name.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc b/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc index e308a5cebc74dc1ce40a001cb54b066b0b685144..d3b1f7b9e5519fe9484a6e625d5dd949dd8a1fe1 100644 --- a/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc +++ b/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/reduce_scatter_parameter.h" +#include "nnacl_c/reduce_scatter_parameter.h" #include "infer/reduce_scatter.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameReduceScatter; diff --git a/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc b/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc index db587a915ba64407f90c507aa5d4aa0a6f746f8c..44259066f54653b01231620073a16473582710a8 100644 --- a/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" #include "infer/resize.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameResize; diff --git a/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc b/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc index d6a6a1319482f00c1af553ea7b926468de465e2e..aad907d5e07c9270221dcb2b9af852c5aaf5b463 100644 --- a/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/reverse_fp32.h" +#include "nnacl_c/fp32/reverse_fp32.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc b/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc index 9a670f877f4ebd23cf112d3c8343037204615e7c..a4c2afa2103011153122a96e2685c3959c9b185c 100644 --- a/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/reverse_sequence_parameter.h" +#include "nnacl_c/reverse_sequence_parameter.h" #include "infer/reverse_sequence.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameReverseSequence; diff --git a/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc b/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc index f0856d7b28a634973a255e77fd0fc3b571e65cef..fe3c8dd6d7c41519e9ae1871899c2970c961e704 100644 --- a/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/roi_pooling_fp32.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" #include "infer/roi_pooling.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" using mindspore::ops::kNameROIPooling; diff --git a/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc b/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc index 6530ef1403019ac0fd6b018d0b3da9783c54c4d0..c51faba0337c08e4d3eaed8d6caa835bfa78d6ce 100644 --- a/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/scale_parameter.h" #include "infer/cxx_api/scale_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameScaleFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc b/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc index 0bef9f13b0f3e54253f27b43d83b5e8694aad980..68c4269605edb87c864ba7c346fdf676bd1ca2c5 100644 --- a/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/scatter_elements_parameter.h" +#include "nnacl_c/scatter_elements_parameter.h" #include "infer/scatter_elements.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameScatterElements; diff --git a/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc b/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc index d1e01509f7aa8455bf383621d7eda377a73948f5..494d74053d3af5ed940673f008913d9346433013 100644 --- a/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" #include "infer/skip_gram.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSkipGram; diff --git a/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc b/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc index bcf3808c910cc4292fca3e2694b74933139670e0..71d773746b32a7b02aec3ac5ed75491f057a0d81 100644 --- a/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/slice_parameter.h" +#include "nnacl_c/slice_parameter.h" #include "infer/cxx_api/slice_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSliceFusion; diff --git a/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc b/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc index 1ddd8b47b5be1165ed0b9438d9689d0da6263fde..aa8b16c2ec3f92648c4904219e0a82a8507994ba 100644 --- a/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kAxis; diff --git a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc index 8c33063dc6c02c6978d3f1e440598f7954dda5b5..302f527e0986dc20c5746fd767029450a5e99538 100644 --- a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc @@ -15,7 +15,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" #include "infer/space_to_batch_nd.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSpaceToBatchND; diff --git a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc index efd292d8cc2c1d8e671f99f24866f4dfc7eaba50..39b8326fad35f04aa46bdf263a01200eb2dd9486 100644 --- a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" #include "infer/space_to_batch.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSpaceToBatch; diff --git a/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc b/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc index c71e27a12cd8c5b880c074714535757f1977c0b0..778c9783d42cd74d1ef974fc86444b97c60bf240 100644 --- a/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" #include "infer/space_to_depth.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSpaceToDepth; diff --git a/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc b/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc index 7261d819e9bf2290d6001ecf56c2bdcf8a56ed00..f52e6de5262098b80f350be62ae63a6d803a4228 100644 --- a/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32_grad/softmax_grad.h" +#include "nnacl_c/fp32_grad/softmax_grad.h" #include "infer/sparse_softmax_cross_entropy_with_logits.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSparseSoftmaxCrossEntropyWithLogits; diff --git a/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc b/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc index 345ae17702297726f9b91a905b6a4559518f0f96..bffd6c9dfe91624881faac39c5097faa88853f74 100644 --- a/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/op_base.h" -#include "nnacl/splice_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/splice_parameter.h" #include "infer/splice.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSplice; diff --git a/mindspore-lite/src/common/ops/operator_populate/split_populate.cc b/mindspore-lite/src/common/ops/operator_populate/split_populate.cc index f20b6d354c6fa632a3e8eeff03e6270806d3b95a..3c2ba4f2d5ebdc2acc9b10158f838fc2137edd90 100644 --- a/mindspore-lite/src/common/ops/operator_populate/split_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/split_populate.cc @@ -15,8 +15,8 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/operator_populate/utils.h" -#include "nnacl/split_parameter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kAxis; diff --git a/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc b/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc index 2dee4eea518e83e1c0a31bcceed2a3d215bdac9d..01f028acb816add0af8ec0035c30cdf487a39389 100644 --- a/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" #include "infer/split_with_overlap.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameSplitWithOverlap; diff --git a/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc b/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc index 3a7730562b71be0d13e113a4583cd20a2c6f8ecf..68fd05f67a0120780459954dab58e5e742964c4a 100644 --- a/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/squeeze_parameter.h" +#include "nnacl_c/squeeze_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kAxis; diff --git a/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc index 114b3a56ae261aa4471251eff7caee6566b639df..ba55af572c45b4ff02fc73bc399cf317fa351f74 100644 --- a/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" #include "infer/stack.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameStack; diff --git a/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc index 9ad629baccf6e25bbf4bfd7450f6d33f6f0e6efa..f97fc71dbf801029b6c0b7b841265bf596f77458 100644 --- a/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" #include "infer/grad/strided_slice_grad.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameStridedSliceGrad; diff --git a/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc index 8495832979e0a2c89bba83cd8a9e2a13b50eb2f0..91973e833e4bcb8993464812ffe63ffbc151e546 100644 --- a/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameStridedSlice; diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc index 6093df8505c6f7c000284a1ee7b73dfca4bd4214..b8cfdba1b9d9ef3ba193d408ed507863ca51bf40 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/tensor_array_parameter.h" +#include "nnacl_c/tensor_array_parameter.h" #include "infer/tensor_array.h" #include "infer/tensor_array_read.h" #include "infer/tensor_array_write.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc index 722b01efbb8babfcfb2b317a301d7323e8863ca7..3b872825a247443b12d1816f9f24fefd5d5bc291 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" #include "infer/tensor_list_from_tensor.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" using mindspore::ops::kNameTensorListFromTensor; diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc index 82878e872fc0feadf9896b657c57c9446aa8cfaa..49e41305c4021fb4f15c8caa59d0ae3dcdb6810c 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" #include "infer/tensor_list_get_item.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" using mindspore::ops::kElement_dtype; diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc index fb42bcca06d38f459c829ab433974c90227bc976..61bbca3cd5826d8f7d3c67903228be9669aad171 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" #include "infer/tensor_list_reserve.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" using mindspore::ops::kNameTensorListReserve; diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc index 16fbb7dcb2d70950c8fcb180faeb2c65f62c5f0f..44b8a8a82bf2296e61cbf11fe56af73c3c3246c8 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" #include "infer/tensor_list_set_item.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" using mindspore::ops::kNameTensorListSetItem; diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc index 9b1611d05923a995b56a281b49e08ddb9f528905..a5dfa96c9ee553d11469054a13c7c165dc091f4c 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" #include "infer/tensor_list_stack.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" using mindspore::ops::kNameTensorListStack; diff --git a/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc index 07767c5814328ae9fb53c2195f739d2fe3262f0a..3e60e3cb11ed02b09e58945cf4715613185ad736 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/base/tile_base.h" +#include "nnacl_c/base/tile_base.h" #include "infer/cxx_api/tile_fusion.h" #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc b/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc index c3e38984fdda440cf7e7cf1ae625822c8b145a27..6d02fe2ecc5178cd9f4302ece1c0df14dade6cf4 100644 --- a/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" #include "infer/cxx_api/topk_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc b/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc index c08c0435fb07875573cf5f0381e22e9368d4c387..df55e61d9757f860625d68dc53e2b295018365dd 100644 --- a/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/operator_populate/operator_populate_register.h" #include "src/common/ops/populate/default_populate.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" #include "infer/uniform_real.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" using mindspore::ops::kNameUniformReal; diff --git a/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc b/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc index fcfb8a5ed4ca130f6b91db2e1458bc511faa69b1..b22ca15b2161658e50c866e217efb771dddc9a42 100644 --- a/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/unsqueeze_parameter.h" +#include "nnacl_c/unsqueeze_parameter.h" #include "infer/unsqueeze.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" using mindspore::ops::kAxis; diff --git a/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc b/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc index ab05ff38c7db4d082387da6dc21aa92efcd229ea..49c195e4217b568b430e3a6c09d8af96eea2ce36 100644 --- a/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/unstack_parameter.h" +#include "nnacl_c/unstack_parameter.h" #include "infer/unstack.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" using mindspore::ops::kAxis; diff --git a/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc b/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc index 59a8c297015f2aaa29e85e56647963f09453f632..121a32da2fc6d12988417a53467b1f91c992e83a 100644 --- a/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc +++ b/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl_c/fp32_grad/activation_grad_fp32.h" using mindspore::schema::PrimitiveType_ActivationGrad; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/activation_populate.cc b/mindspore-lite/src/common/ops/populate/activation_populate.cc index 1cd1d966d9cd4dc9d020796e163def998ffb49ad..893a57a0db57212c44aae5eb05154a3a12f7bd0f 100644 --- a/mindspore-lite/src/common/ops/populate/activation_populate.cc +++ b/mindspore-lite/src/common/ops/populate/activation_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" using mindspore::schema::PrimitiveType_Activation; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/adam_populate.cc b/mindspore-lite/src/common/ops/populate/adam_populate.cc index 6dd94f7b890a0a292cdae6072999f0fc195ce34e..119130e5301da0982d469d1783b72cb35a46570a 100644 --- a/mindspore-lite/src/common/ops/populate/adam_populate.cc +++ b/mindspore-lite/src/common/ops/populate/adam_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::schema::PrimitiveType_Adam; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/add_populate.cc b/mindspore-lite/src/common/ops/populate/add_populate.cc index 1de02b45b6c8ca15b60aee3ae709ba9bb585a481..e529d2aef74acb4efeb4d60747b5c7205e5cdf36 100644 --- a/mindspore-lite/src/common/ops/populate/add_populate.cc +++ b/mindspore-lite/src/common/ops/populate/add_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" #include "src/common/ops/populate/arithmetic_populate.h" using mindspore::schema::PrimitiveType_AddFusion; diff --git a/mindspore-lite/src/common/ops/populate/adder_populate.cc b/mindspore-lite/src/common/ops/populate/adder_populate.cc index a09e00643afde248a99407150618ea8ee406bcb7..337c23fe8bdb5eb6f478d7999bc309ba46c699fe 100644 --- a/mindspore-lite/src/common/ops/populate/adder_populate.cc +++ b/mindspore-lite/src/common/ops/populate/adder_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/log_adapter.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_AdderFusion; diff --git a/mindspore-lite/src/common/ops/populate/affine_populate.cc b/mindspore-lite/src/common/ops/populate/affine_populate.cc index 3e780d8cde92595179323ede21a736d3c6c13cce..7730a09c3b230d069a38ad282b0ca9176c0e1db1 100644 --- a/mindspore-lite/src/common/ops/populate/affine_populate.cc +++ b/mindspore-lite/src/common/ops/populate/affine_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/op_base.h" -#include "nnacl/affine_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/affine_parameter.h" using mindspore::schema::PrimitiveType_Affine; diff --git a/mindspore-lite/src/common/ops/populate/all_gather.cc b/mindspore-lite/src/common/ops/populate/all_gather.cc index ea8c1f1e8b119687e8648d6b559a961fad9f09fa..3c1e2b7da1344053a230cc206624c6b724f9c9ec 100644 --- a/mindspore-lite/src/common/ops/populate/all_gather.cc +++ b/mindspore-lite/src/common/ops/populate/all_gather.cc @@ -16,7 +16,7 @@ #include "schema/ops_generated.h" #include "schema/model_generated.h" -#include "nnacl/all_gather_parameter.h" +#include "nnacl_c/all_gather_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_AllGather; diff --git a/mindspore-lite/src/common/ops/populate/argmax_populate.cc b/mindspore-lite/src/common/ops/populate/argmax_populate.cc index 49639aa37160efad2415f702062ee4708c4b9ed5..0653820678c866d3795ccfefe9f42b08aa6c5776 100644 --- a/mindspore-lite/src/common/ops/populate/argmax_populate.cc +++ b/mindspore-lite/src/common/ops/populate/argmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arg_min_max_parameter.h" +#include "nnacl_c/arg_min_max_parameter.h" using mindspore::schema::PrimitiveType_ArgMaxFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/argmin_populate.cc b/mindspore-lite/src/common/ops/populate/argmin_populate.cc index 280d1edb960f4915295429d08932863932d4b97d..730daaf4bba67b4493835313fb26d6d955f25069 100644 --- a/mindspore-lite/src/common/ops/populate/argmin_populate.cc +++ b/mindspore-lite/src/common/ops/populate/argmin_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arg_min_max_parameter.h" +#include "nnacl_c/arg_min_max_parameter.h" using mindspore::schema::PrimitiveType_ArgMinFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/arithmetic_populate.h b/mindspore-lite/src/common/ops/populate/arithmetic_populate.h index 7601f25c3d92168c91ea75b1f8d539bff0f9e3f8..e79fa2c45e1d114456fbbbe015d9c8cbf3c2d6cc 100644 --- a/mindspore-lite/src/common/ops/populate/arithmetic_populate.h +++ b/mindspore-lite/src/common/ops/populate/arithmetic_populate.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_ #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_ -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc b/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc index 05e0253b9dabffd2f5d5b07b8029313efd5eadc1..69e3725bc09f4cc8fca9dda1619892825288850d 100644 --- a/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc +++ b/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/log_adapter.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/arithmetic_self_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_Abs; using mindspore::schema::PrimitiveType_Ceil; diff --git a/mindspore-lite/src/common/ops/populate/attention_populate.cc b/mindspore-lite/src/common/ops/populate/attention_populate.cc index 69c0bdd5537c4422124aa2b664a3b778c3624478..75c86cf08652466238990caf3a7b8c88b3505ce6 100644 --- a/mindspore-lite/src/common/ops/populate/attention_populate.cc +++ b/mindspore-lite/src/common/ops/populate/attention_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/attention_parameter.h" +#include "nnacl_c/attention_parameter.h" using mindspore::schema::PrimitiveType_Attention; diff --git a/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc b/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc index b2edc51ec3b5470b1f3ff927d9670b23a22b558c..1eb73d6189b0cb5911c7b262a71fc51ebf6c27c9 100644 --- a/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc +++ b/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/infer/audio_spectrogram_infer.h" +#include "nnacl_c/infer/audio_spectrogram_infer.h" using mindspore::schema::PrimitiveType_AudioSpectrogram; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc b/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc index 9923672e0261dcbe6128b93a1ae9b4a3b27f18ea..7e2f54185929e946ef01940a1b3a2bf9ce06d27d 100644 --- a/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" using mindspore::schema::PrimitiveType_BatchNorm; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc b/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc index 4820b6c7a06d6cd9cceb7ef42e6cf759db5ca995..14105093885f2f31948032ef5af61727245201dc 100644 --- a/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc +++ b/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" using mindspore::schema::PrimitiveType_BatchToSpace; using mindspore::schema::PrimitiveType_BatchToSpaceND; diff --git a/mindspore-lite/src/common/ops/populate/bias_add_populate.cc b/mindspore-lite/src/common/ops/populate/bias_add_populate.cc index aa2b65937423e438c7eee3561dcfe91ecb5016ca..bdac3bef186aa6336a95a3bbbf92bee688615d69 100644 --- a/mindspore-lite/src/common/ops/populate/bias_add_populate.cc +++ b/mindspore-lite/src/common/ops/populate/bias_add_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::schema::PrimitiveType_BiasAdd; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc b/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc index a8c51d1c0a758a05f933d6442b388751caa9c78f..59017fab3ab987bb70daeb256831366524d3b2a6 100644 --- a/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc +++ b/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/base/broadcast_to.h" +#include "nnacl_c/base/broadcast_to.h" using mindspore::schema::PrimitiveType_BroadcastTo; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/call_populate.cc b/mindspore-lite/src/common/ops/populate/call_populate.cc index 105ff5eee25ab75d8fdf88d581b4155cb5e228a4..67da1fff192e6b2a19e07a2979021410538437d9 100644 --- a/mindspore-lite/src/common/ops/populate/call_populate.cc +++ b/mindspore-lite/src/common/ops/populate/call_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/call_parameter.h" +#include "nnacl_c/call_parameter.h" using mindspore::schema::PrimitiveType_Call; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/clip_populate.cc b/mindspore-lite/src/common/ops/populate/clip_populate.cc index df1da96a65c5f85b1a0ce14a9263d7aeeb0fee0e..40d8155e5be3ad7aa3cefa48a145f581d91f2c18 100644 --- a/mindspore-lite/src/common/ops/populate/clip_populate.cc +++ b/mindspore-lite/src/common/ops/populate/clip_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/clip_parameter.h" +#include "nnacl_c/clip_parameter.h" using mindspore::schema::PrimitiveType_Clip; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/concat_populate.cc b/mindspore-lite/src/common/ops/populate/concat_populate.cc index 485f711461d4affb7120d8826cc0b0d2a89ee7b2..e86126c0c6afdde4668c3da92a1d45b4378b40d2 100644 --- a/mindspore-lite/src/common/ops/populate/concat_populate.cc +++ b/mindspore-lite/src/common/ops/populate/concat_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" using mindspore::schema::PrimitiveType_Concat; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc index 231c2dec914959d1b2e981e91e59d1e8a3a3af94..519c9479940a0a7b4c1ecc1905762ff08e68e67b 100644 --- a/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc +++ b/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc @@ -15,7 +15,7 @@ */ #include "ir/dtype/type_id.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/constant_of_shape_parameter.h" +#include "nnacl_c/constant_of_shape_parameter.h" using mindspore::schema::PrimitiveType_ConstantOfShape; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc index b834071a563974a1355fbda59cd387917e7872d4..6858ebf77882fbd309a4d854014abb6bbe497da9 100644 --- a/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc +++ b/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/op_base.h" -#include "nnacl/tensor_array_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_array_parameter.h" using mindspore::schema::PrimitiveType_TensorArray; using mindspore::schema::PrimitiveType_TensorArrayRead; diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc index cd72d1f929d697da25658b089c09589508720d0a..83239162514d58791d474bb55941b70e5a43efd7 100644 --- a/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc +++ b/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_TensorListFromTensor; diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc index 0374044d333d6da0b0615c136ccaad2c71cacfdc..6459563665bd862c8e1f7e3557559fee43af14b8 100644 --- a/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc +++ b/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" using mindspore::schema::PrimitiveType_TensorListGetItem; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc index 306f2529867436bef2cdb259531347d8fc63c0c5..f7590986b973265cfc4e97571210f35c7f1ec7ae 100644 --- a/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc +++ b/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" using mindspore::schema::PrimitiveType_TensorListReserve; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc index 7cfebab481c363daad09ccb26ba2860759e68b69..1e2107d76b4e1dd0110f729d9778794da2ee399d 100644 --- a/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc +++ b/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" using mindspore::schema::PrimitiveType_TensorListSetItem; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc index b053c2512ae262e37cd274f0fc8ecf8e0545d8f4..1e720291b9c774f2e8902ba4198f02102040e4dc 100644 --- a/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc +++ b/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" using mindspore::schema::PrimitiveType_TensorListStack; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/conv2d_populate.cc b/mindspore-lite/src/common/ops/populate/conv2d_populate.cc index 79d8b792fdc98b854989f9a4ed05f898db5ef857..3310cdb16d5315ea5341eaf4a3dd9a0e4a854c17 100644 --- a/mindspore-lite/src/common/ops/populate/conv2d_populate.cc +++ b/mindspore-lite/src/common/ops/populate/conv2d_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_Conv2DFusion; diff --git a/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc b/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc index 096f51a6ab16e38afb0911a2e6e93251b35f57c8..470727644f1bb5751ab1314a6a7a8b5f81b7a186 100644 --- a/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc +++ b/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" using mindspore::schema::PrimitiveType_CropAndResize; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/crop_populate.cc b/mindspore-lite/src/common/ops/populate/crop_populate.cc index 7db5c4b5d63f14aaeb8b3a607d1d49c35d6ece1a..357a36b0af6814df8c3bf1be9dfc16f658f8839d 100644 --- a/mindspore-lite/src/common/ops/populate/crop_populate.cc +++ b/mindspore-lite/src/common/ops/populate/crop_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/crop_parameter.h" +#include "nnacl_c/crop_parameter.h" using mindspore::schema::PrimitiveType_Crop; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/cumsum_populate.cc b/mindspore-lite/src/common/ops/populate/cumsum_populate.cc index 76fc45a3658d3691911ae1aee63a9f39a946ff2a..af59a8a3bbceeaa83d0631271d87a5b236c52ccb 100644 --- a/mindspore-lite/src/common/ops/populate/cumsum_populate.cc +++ b/mindspore-lite/src/common/ops/populate/cumsum_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/cumsum_parameter.h" using mindspore::schema::PrimitiveType_CumSum; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/custom_populate.cc b/mindspore-lite/src/common/ops/populate/custom_populate.cc index 0cde665bc788a80d15126b09d507b391016b0435..84ea9ace12eeef3774c0db97b260d5d176817916 100644 --- a/mindspore-lite/src/common/ops/populate/custom_populate.cc +++ b/mindspore-lite/src/common/ops/populate/custom_populate.cc @@ -16,14 +16,14 @@ #include "src/common/ops/populate/populate_register.h" #include "src/common/log_adapter.h" #include "src/tensor.h" -#include "nnacl/custom_parameter.h" -#include "nnacl/split_parameter.h" -#include "nnacl/custom_gru_parameter.h" -#include "nnacl/custom_masked_fill_parameter.h" -#include "nnacl/custom_is_inf_parameter.h" -#include "nnacl/scatter_nd_parameter.h" -#include "nnacl/conv3d_parameter.h" -#include "nnacl/grid_sampler_parameter.h" +#include "nnacl_c/custom_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/custom_gru_parameter.h" +#include "nnacl_c/custom_masked_fill_parameter.h" +#include "nnacl_c/custom_is_inf_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" +#include "nnacl_c/conv3d_parameter.h" +#include "nnacl_c/grid_sampler_parameter.h" using mindspore::schema::PrimitiveType_Custom; diff --git a/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc b/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc index 4b6b6a76012b84c045b0f83ffbe247fa288adc40..25c8fcbf012d4bb95f4e105d179612c436b5e050 100644 --- a/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc +++ b/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/log_adapter.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/default_populate.h b/mindspore-lite/src/common/ops/populate/default_populate.h index 6ee48376a1e68f622e1c6fe2e748b33141d9bdce..d13aaa859f6bc89a6d90785ff005ace2784169f7 100644 --- a/mindspore-lite/src/common/ops/populate/default_populate.h +++ 
@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_DEFAULT_POPULATE_H_
 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_DEFAULT_POPULATE_H_
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc b/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc
index 0c7f6a6bed3d4b96470b80eb18bc484dc30a637c..5b83359881d3a439694c5e4d0ec213d359436fb9 100644
--- a/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/depth_to_space_parameter.h"
+#include "nnacl_c/depth_to_space_parameter.h"
 using mindspore::schema::PrimitiveType_DepthToSpace;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc b/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc
index 6cfdb35cd56a85c0a9e9cbbadc2973a930f7fed7..bbb625d3fd6459d957f8bdd83dfe1ce64aad3d51 100644
--- a/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/detection_post_process_parameter.h"
+#include "nnacl_c/detection_post_process_parameter.h"
 using mindspore::schema::PrimitiveType_DetectionPostProcess;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc b/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc
index 8e3933209d3bef9408dd70c5e5fe9a5e59208f00..86f732f9bc3419cf85ba50638f7785486aea63aa 100644
--- a/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/dynamic_quant_parameter.h"
+#include "nnacl_c/dynamic_quant_parameter.h"
 using mindspore::schema::PrimitiveType_DynamicQuant;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc b/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc
index 87b56c026bc64a5b0015d17f0f255cbc71b3b532..5173b69ba90dde123cff0964ca0276425400e3c5 100644
--- a/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/embedding_lookup_fp32.h"
+#include "nnacl_c/fp32/embedding_lookup_fp32.h"
 using mindspore::schema::PrimitiveType_EmbeddingLookupFusion;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/exp_populate.cc b/mindspore-lite/src/common/ops/populate/exp_populate.cc
index 86c024561452c03624909bb47a4e60aa820dbb10..c14d969c95bc561e247c5d5401adf17c34069dd1 100644
--- a/mindspore-lite/src/common/ops/populate/exp_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/exp_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/exp_fp32.h" +#include "nnacl_c/fp32/exp_fp32.h" using mindspore::schema::PrimitiveType_ExpFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/flatten_populate.cc b/mindspore-lite/src/common/ops/populate/flatten_populate.cc index f4f30e770d2354282aee544d56236c541e89738c..045bd5361ceeedcd9311cee9b27fb81d5a50665f 100644 --- a/mindspore-lite/src/common/ops/populate/flatten_populate.cc +++ b/mindspore-lite/src/common/ops/populate/flatten_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/flatten_parameter.h" +#include "nnacl_c/flatten_parameter.h" using mindspore::schema::PrimitiveType_Flatten; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/full_connection_populate.cc b/mindspore-lite/src/common/ops/populate/full_connection_populate.cc index 30106e647877b27f11dae1f449562ed5e811db9a..00b57404e01ec658a8a3146de28b5df5c4dfb1fe 100644 --- a/mindspore-lite/src/common/ops/populate/full_connection_populate.cc +++ b/mindspore-lite/src/common/ops/populate/full_connection_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" using mindspore::schema::PrimitiveType_FullConnection; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc b/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc index a23fb7077993305d945f4a48416d1c857e832ebd..e5ad31d5cf995e1f45c0e836998db0bf77d0ba9d 100644 --- a/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" using mindspore::schema::PrimitiveType_FusedBatchNorm; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/gather_d_populate.cc b/mindspore-lite/src/common/ops/populate/gather_d_populate.cc index b05fcfeecd94fd25827225e3576bcd86eae0d519..a1bee47b0836b90364649bbc495f844a585afa6d 100644 --- a/mindspore-lite/src/common/ops/populate/gather_d_populate.cc +++ b/mindspore-lite/src/common/ops/populate/gather_d_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/gather_parameter.h" using mindspore::schema::PrimitiveType_GatherD; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc b/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc index 980a1adfa91cb813d3e9b09504e8d9d271e26a9f..cf79b247b5b05997a7c7d8c017d5b9e16329be50 100644 --- a/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc +++ b/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/gather_nd_parameter.h" +#include "nnacl_c/gather_nd_parameter.h" using mindspore::schema::PrimitiveType_GatherNd; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/gather_populate.cc b/mindspore-lite/src/common/ops/populate/gather_populate.cc index 7e19ccd904287767550f064a1c4c5fe18e920fd5..1ac8e829eb925a5c87b78f47e81d618c31de66c7 100644 --- a/mindspore-lite/src/common/ops/populate/gather_populate.cc +++ b/mindspore-lite/src/common/ops/populate/gather_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/gather_parameter.h" using mindspore::schema::PrimitiveType_Gather; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/glu_populate.cc b/mindspore-lite/src/common/ops/populate/glu_populate.cc index 96d232660bb43bdeed53e8235eefbfc44513f4e3..4d2dbd89627967c380263b3013fc35667afc5b9a 100644 --- a/mindspore-lite/src/common/ops/populate/glu_populate.cc +++ b/mindspore-lite/src/common/ops/populate/glu_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/glu_parameter.h" +#include "nnacl_c/glu_parameter.h" using mindspore::schema::PrimitiveType_GLU; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/group_norm_populate.cc b/mindspore-lite/src/common/ops/populate/group_norm_populate.cc index c832e705fe8d3174917bf9cc7b9a0def505580f0..59d3f6e735fbb61fda0319f1938840dfc48a399b 100644 --- a/mindspore-lite/src/common/ops/populate/group_norm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/group_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/group_norm_parameter.h" +#include "nnacl_c/group_norm_parameter.h" using mindspore::schema::PrimitiveType_GroupNormFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/gru_populate.cc b/mindspore-lite/src/common/ops/populate/gru_populate.cc index ed157b6b8d193a1694641ab1ee872883803018ad..df247441aa4c817f64a7850d0ae00f0eeaa18d8b 100644 --- a/mindspore-lite/src/common/ops/populate/gru_populate.cc +++ b/mindspore-lite/src/common/ops/populate/gru_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/gru_fp32.h" +#include "nnacl_c/fp32/gru_fp32.h" using mindspore::schema::PrimitiveType_GRU; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc b/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc index 71acd6e378f8094db64a013bd15cfd10dc1a4210..dcc223c28563906d203a7470e5b5dc54853e2794 100644 --- a/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/instance_norm_parameter.h" +#include "nnacl_c/instance_norm_parameter.h" using mindspore::schema::PrimitiveType_InstanceNorm; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc b/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc index c1dc48da0f457ebe53a12cc0ce50f6bbb8dbcffc..46850753cb606e12cae3248ab2da926c36945dca 100644 --- a/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/l2_norm_parameter.h" +#include "nnacl_c/l2_norm_parameter.h" using mindspore::schema::PrimitiveType_L2NormalizeFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc b/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc index fc41e4d651508e7d94c69c0603926b4011d50bc9..498a23e284693a85ab1bec6512a824e372e44b6f 100644 --- a/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc +++ b/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "nnacl/fp32_grad/layernormgrad_parameter.h" +#include "nnacl_c/fp32_grad/layernormgrad_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_LayerNormGrad; diff --git a/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc b/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc index 9da07bfc69a9d2225561800aa65ad4515dfe473f..5ee8056574f1c245d22e5d985225b5b3bd0d7baf 100644 --- a/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" #include #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_LayerNormFusion; diff --git a/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc b/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc index 2b372e602ac5ef59176b8ce3be4b1824d0310ee1..194395872f6cf520a94c85edd1dc4b575cd7c5a5 100644 --- a/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc +++ b/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/local_response_norm_fp32.h" +#include "nnacl_c/fp32/local_response_norm_fp32.h" using mindspore::schema::PrimitiveType_LRN; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc b/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc index 40ae66b368b226e1a5ca35a588746696f8216aba..fe720d40eef8c7c494b117fa5c89bff84e01d983 100644 --- a/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc +++ b/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" using mindspore::schema::PrimitiveType_LogSoftmax; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/lstm_populate.cc b/mindspore-lite/src/common/ops/populate/lstm_populate.cc index b3a85b64b57dfa11498420c72729954aa7294b11..65c3d9ecdbbebda7df5d47c9e7f2acfe7874e4a7 100644 --- a/mindspore-lite/src/common/ops/populate/lstm_populate.cc +++ b/mindspore-lite/src/common/ops/populate/lstm_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" using mindspore::schema::PrimitiveType_LSTM; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/matmul_populate.cc b/mindspore-lite/src/common/ops/populate/matmul_populate.cc index 8eb182b826dbf011b0f74f4b7691fa38cac17acd..803bedda5c29c3d1aba7999a5ec4d813c3334314 100644 --- a/mindspore-lite/src/common/ops/populate/matmul_populate.cc +++ b/mindspore-lite/src/common/ops/populate/matmul_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" using mindspore::schema::PrimitiveType_MatMulFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/mfcc_populate.cc b/mindspore-lite/src/common/ops/populate/mfcc_populate.cc index 3b7fc3d8860300af2b387ca30e1ef249e5a1bdcc..854872c68f5dc991d8dde410229cc9aaeac49f55 100644 --- a/mindspore-lite/src/common/ops/populate/mfcc_populate.cc +++ b/mindspore-lite/src/common/ops/populate/mfcc_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/infer/mfcc_infer.h" +#include "nnacl_c/infer/mfcc_infer.h" using mindspore::schema::PrimitiveType_Mfcc; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/mul_populate.cc b/mindspore-lite/src/common/ops/populate/mul_populate.cc index 3b3c5df038c7b54f38d4da731f70d0effb7cd980..3524d842bf6a78640d88b7146137f8619a06e874 100644 --- a/mindspore-lite/src/common/ops/populate/mul_populate.cc +++ b/mindspore-lite/src/common/ops/populate/mul_populate.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"
 #include "src/common/ops/populate/populate_register.h"
 #include "src/common/ops/populate/arithmetic_populate.h"
 using mindspore::schema::PrimitiveType_MulFusion;
diff --git a/mindspore-lite/src/common/ops/populate/nllloss_populate.cc b/mindspore-lite/src/common/ops/populate/nllloss_populate.cc
index 9a3c9f44e1c6012e96f7c9c5302cdfc5f4dcd681..86814c60cc5feada519e4c7de890346af47e9168 100644
--- a/mindspore-lite/src/common/ops/populate/nllloss_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/nllloss_populate.cc
@@ -16,7 +16,7 @@
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/nllloss_parameter.h"
+#include "nnacl_c/nllloss_parameter.h"
 
 using mindspore::schema::PrimitiveType_NLLLoss;
 using mindspore::schema::PrimitiveType_NLLLossGrad;
diff --git a/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc b/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc
index 485ff9c2d07c7470c96c01a21fee19edcc96b32e..c9db6507ca2de3caae09767dc723ead6d0395093 100644
--- a/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/non_max_suppression_parameter.h"
+#include "nnacl_c/non_max_suppression_parameter.h"
 using mindspore::schema::PrimitiveType_NonMaxSuppression;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/one_hot_populate.cc b/mindspore-lite/src/common/ops/populate/one_hot_populate.cc
index 18caaa3d688d4559eb7ac72f1bf42ff99be0ea91..efd737dc06c3adce2461559266fea92d2ea3dc66 100644
--- a/mindspore-lite/src/common/ops/populate/one_hot_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/one_hot_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/one_hot_fp32.h"
+#include "nnacl_c/fp32/one_hot_fp32.h"
 using mindspore::schema::PrimitiveType_OneHot;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/p_relu_populate.cc b/mindspore-lite/src/common/ops/populate/p_relu_populate.cc
index cda27de985a530b9a9b1ee6964164adc034f2a51..48a544d7768298f518204f495e7b18d0921cb4b2 100644
--- a/mindspore-lite/src/common/ops/populate/p_relu_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/p_relu_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/prelu_parameter.h"
+#include "nnacl_c/prelu_parameter.h"
 using mindspore::schema::PrimitiveType_PReLUFusion;
 
 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/pad_populate.cc b/mindspore-lite/src/common/ops/populate/pad_populate.cc
index ac4171869e7a5c13ef2663f206cdc4f21394ac1e..2ea13a6512f0d612a32a7fd7f028e00796a097d2 100644
--- a/mindspore-lite/src/common/ops/populate/pad_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/pad_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" using mindspore::schema::PrimitiveType_PadFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/partial_populate.cc b/mindspore-lite/src/common/ops/populate/partial_populate.cc index b5516686569598e622870c017840d36d3561b7e9..cab3b01e0f9c6df3e24770da4b9acd24e9faed2e 100644 --- a/mindspore-lite/src/common/ops/populate/partial_populate.cc +++ b/mindspore-lite/src/common/ops/populate/partial_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/partial_fusion_parameter.h" +#include "nnacl_c/partial_fusion_parameter.h" using mindspore::schema::PrimitiveType_PartialFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/pooling_populate.cc b/mindspore-lite/src/common/ops/populate/pooling_populate.cc index dd9fd519fb38e8a3bbbb76404fa288ccaaab4a80..7f401afda1590541b7fc42d3684cd29d5792c75d 100644 --- a/mindspore-lite/src/common/ops/populate/pooling_populate.cc +++ b/mindspore-lite/src/common/ops/populate/pooling_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" using mindspore::schema::PrimitiveType_AvgPoolFusion; using mindspore::schema::PrimitiveType_MaxPoolFusion; diff --git a/mindspore-lite/src/common/ops/populate/populate_register.h b/mindspore-lite/src/common/ops/populate/populate_register.h index 226c58d467135a08e893156385d3ceddd6b20792..428849b549aad5456be765bbbf7e6ce19f656c18 100644 --- a/mindspore-lite/src/common/ops/populate/populate_register.h +++ b/mindspore-lite/src/common/ops/populate/populate_register.h @@ -20,7 +20,7 @@ #include #include #include "schema/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "src/common/log_adapter.h" #include "src/common/prim_util.h" diff --git a/mindspore-lite/src/common/ops/populate/power_populate.cc b/mindspore-lite/src/common/ops/populate/power_populate.cc index 2559626bc6fbed9f32aa7e555973ad0d0a7e5e85..33450e87330997bf7e0e048bba959ed04bd295b3 100644 --- a/mindspore-lite/src/common/ops/populate/power_populate.cc +++ b/mindspore-lite/src/common/ops/populate/power_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/pow_parameter.h" +#include "nnacl_c/pow_parameter.h" using mindspore::schema::PrimitiveType_PowFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/prior_box_populate.cc b/mindspore-lite/src/common/ops/populate/prior_box_populate.cc index 60e66233582af60b657e39dd27f9c01218a83ef7..197668597e7d7aaa91491c186bbe4927887d1f78 100644 --- a/mindspore-lite/src/common/ops/populate/prior_box_populate.cc +++ b/mindspore-lite/src/common/ops/populate/prior_box_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/prior_box_parameter.h" +#include "nnacl_c/prior_box_parameter.h" using mindspore::schema::PrimitiveType_PriorBox; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc b/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc index 028de9f38a7f9bab08d4cc3afe8f7326ea26c03f..35035e8df8467ceb80ebf040b9c12e488f7205a8 100644 --- a/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc +++ b/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" using mindspore::schema::PrimitiveType_QuantDTypeCast; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/random_normal_populate.cc b/mindspore-lite/src/common/ops/populate/random_normal_populate.cc index 7956640654dc5bff62bf2303d897921943ed0d38..5c92ead74369051ef119b7e7283d20351212b8bf 100644 --- a/mindspore-lite/src/common/ops/populate/random_normal_populate.cc +++ b/mindspore-lite/src/common/ops/populate/random_normal_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" using mindspore::schema::PrimitiveType_RandomNormal; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc b/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc index f432ae45af31eda5e1eb815f24f4f3008ef19364..e4da6342e83dcb2e6e3d6557322326706fb7faf5 100644 --- a/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc +++ b/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" using mindspore::schema::PrimitiveType_RandomStandardNormal; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/range_populate.cc b/mindspore-lite/src/common/ops/populate/range_populate.cc index 1132567717bd33a1ccebae390f81d991865ce15a..6bb081a0484c2f8981d41ac817a89c154587678d 100644 --- a/mindspore-lite/src/common/ops/populate/range_populate.cc +++ b/mindspore-lite/src/common/ops/populate/range_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/range_parameter.h" +#include "nnacl_c/range_parameter.h" using mindspore::schema::PrimitiveType_Range; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/reduce_populate.cc b/mindspore-lite/src/common/ops/populate/reduce_populate.cc index da4d39175f4a6abc38a8431ded1703e91c4db64d..661b091e4cdffa06aa1b7bb045b930fd9a5e78bc 100644 --- a/mindspore-lite/src/common/ops/populate/reduce_populate.cc +++ b/mindspore-lite/src/common/ops/populate/reduce_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" using mindspore::schema::PrimitiveType_ReduceFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/reduce_scatter.cc b/mindspore-lite/src/common/ops/populate/reduce_scatter.cc index 1e02e6e8c0d72b5a23b6b525ca35b3619ee1fac6..0025cd577c2a89c9b8188dfdc4fdbf0d06f6c7a7 100644 --- a/mindspore-lite/src/common/ops/populate/reduce_scatter.cc +++ b/mindspore-lite/src/common/ops/populate/reduce_scatter.cc @@ -16,7 +16,7 @@ #include "schema/ops_generated.h" #include "schema/model_generated.h" -#include "nnacl/reduce_scatter_parameter.h" +#include "nnacl_c/reduce_scatter_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_ReduceScatter; diff --git a/mindspore-lite/src/common/ops/populate/reshape_populate.cc b/mindspore-lite/src/common/ops/populate/reshape_populate.cc index d34bcb8a30223236ba2ba37072907b3c0eeba0b8..82e6fab5d89ca2c1870eb164ce0538e1cc1b6150 100644 --- a/mindspore-lite/src/common/ops/populate/reshape_populate.cc +++ b/mindspore-lite/src/common/ops/populate/reshape_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reshape_parameter.h" +#include "nnacl_c/reshape_parameter.h" using mindspore::schema::PrimitiveType_Reshape; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/resize_populate.cc b/mindspore-lite/src/common/ops/populate/resize_populate.cc index 0d8ae5daf71eb78a6c0ab14460f2b97baa985684..a46cdd6038ebe1dbb3877ecdee2e6082f6251069 100644 --- a/mindspore-lite/src/common/ops/populate/resize_populate.cc +++ b/mindspore-lite/src/common/ops/populate/resize_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" using mindspore::schema::PrimitiveType_Resize; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/reverse_populate.cc b/mindspore-lite/src/common/ops/populate/reverse_populate.cc index 59ef7d234ee64899c7f73880dfdb2272d79dc8d5..3d16522a18faeb59c6de7901e84eadb60f2b216d 100644 --- a/mindspore-lite/src/common/ops/populate/reverse_populate.cc +++ b/mindspore-lite/src/common/ops/populate/reverse_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/reverse_fp32.h" +#include "nnacl_c/fp32/reverse_fp32.h" using mindspore::schema::PrimitiveType_ReverseV2; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc b/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc index 761457d6dd868464a0e534dcbba277256e35e664..5896c32dfae8f1fa0b08c499c0abdd1fe8efabd8 100644 --- a/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc +++ b/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reverse_sequence_parameter.h" +#include "nnacl_c/reverse_sequence_parameter.h" using mindspore::schema::PrimitiveType_ReverseSequence; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc b/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc index cf1f9d6fcf9df629a020d3c0cba2ae358a3560d3..c47a449b04aedcbd79f81dfd626b47b083f0cfd8 100644 --- a/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc +++ b/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/roi_pooling_fp32.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" using mindspore::schema::PrimitiveType_ROIPooling; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/scale_populate.cc b/mindspore-lite/src/common/ops/populate/scale_populate.cc index 530543c2d9d6c1b9f417b667e115e857d9318e17..780696d9a23735d8d7291f93ee9a20aefaa9f8cc 100644 --- a/mindspore-lite/src/common/ops/populate/scale_populate.cc +++ b/mindspore-lite/src/common/ops/populate/scale_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/scale_parameter.h" using mindspore::schema::PrimitiveType_ScaleFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc b/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc index 135fd26111c8b16bbd0e6c878cf9cd4f354faabb..fe7b83435c25db9568a2e0a66dcadadb73b3d3e7 100644 --- a/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc +++ b/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scatter_elements_parameter.h" +#include "nnacl_c/scatter_elements_parameter.h" using mindspore::schema::PrimitiveType_ScatterElements; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc b/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc index 7aa054ea779bc972be9271c5ad5c304f416c2d0c..897bd85424d922626c1c35934b1fad7f44d2dd91 100644 --- a/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc +++ b/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" using mindspore::schema::PrimitiveType_ScatterNd; diff --git a/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc b/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc index c8002c2285dbfb20e81b4c7766d79c7d282317e0..e75246afa9948ffffe008ca60b98d3a39ec3fe12 100644 --- a/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc +++ b/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" using mindspore::schema::PrimitiveType_ScatterNdUpdate; using mindspore::schema::PrimitiveType_TensorScatterAdd; diff --git a/mindspore-lite/src/common/ops/populate/slice_populate.cc b/mindspore-lite/src/common/ops/populate/slice_populate.cc index c41899ebe57d12bf0cc3c4ef4cf8d2ae172fd46c..270029f73c39306897779892ff25cbf2a871f0b6 100644 --- a/mindspore-lite/src/common/ops/populate/slice_populate.cc +++ b/mindspore-lite/src/common/ops/populate/slice_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/slice_parameter.h" +#include "nnacl_c/slice_parameter.h" using mindspore::schema::PrimitiveType_SliceFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/softmax_populate.cc b/mindspore-lite/src/common/ops/populate/softmax_populate.cc index 66ccdc3bf70569d290929f278d4147ba43371fde..8821d3a58327f505ddb8c84f04ba491a6c07ab64 100644 --- a/mindspore-lite/src/common/ops/populate/softmax_populate.cc +++ b/mindspore-lite/src/common/ops/populate/softmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" using mindspore::schema::PrimitiveType_Softmax; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc b/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc index bc7b279984aab0f92c2e1f541a585d224cfa9f91..201f7f0bde589609dc4b64706718eb9181738fa9 100644 --- a/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc +++ b/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" using mindspore::schema::PrimitiveType_SpaceToBatchND; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc b/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc index 95ef85ea7d4524188a951da93c7d9bcbe8d1d29f..43ce9f3570fc76f7906e940770a101cd990e86d5 100644 --- a/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc +++ b/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" using mindspore::schema::PrimitiveType_SpaceToBatch; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc b/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc index 0b37c433c11f0186d8288b4ed914286f326a717c..6454ac01c817d963e408ee765587200be99468f0 100644 --- a/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc +++ b/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" using mindspore::schema::PrimitiveType_SpaceToDepth; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc b/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc index f970e47ef197dc5d1ceba83987fe6afbfc71842b..74ccfe434516797e8aa93dd2a217bce6026fcd3f 100644 --- a/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc +++ b/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32_grad/softmax_grad.h" +#include "nnacl_c/fp32_grad/softmax_grad.h" using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc b/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc index d4ac7f8e91909c16ab17c0642447520f8df1551d..f943e7b4e0bb4700bdfd0bd65334545930d5ddc1 100644 --- a/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc +++ b/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/sparse_to_dense_parameter.h" +#include "nnacl_c/sparse_to_dense_parameter.h" using mindspore::schema::PrimitiveType_SparseToDense; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/splice_populate.cc b/mindspore-lite/src/common/ops/populate/splice_populate.cc index 32767fde1c16e6709cd21c6098df644255d6b6db..f2b73334968c60c5d66ad4a033dc8c80f07aecac 100644 --- a/mindspore-lite/src/common/ops/populate/splice_populate.cc +++ b/mindspore-lite/src/common/ops/populate/splice_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/op_base.h" -#include "nnacl/splice_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/splice_parameter.h" using mindspore::schema::PrimitiveType_Splice; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/split_populate.cc b/mindspore-lite/src/common/ops/populate/split_populate.cc index 83e3e67f92f68160ae586aa305555226400ed55c..8bb30d3e1906a800ff49fc5c7ca02bbe6ea16119 100644 --- a/mindspore-lite/src/common/ops/populate/split_populate.cc +++ b/mindspore-lite/src/common/ops/populate/split_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/op_base.h" using mindspore::schema::PrimitiveType_Split; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc b/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc index c44027628c22b0f5f7d8ab84a1384ec1ebfe1a24..485b3bb37f66d433f8a615929912044426bd0a06 100644 --- a/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc +++ b/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" using mindspore::schema::PrimitiveType_SplitWithOverlap; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/squeeze_populate.cc b/mindspore-lite/src/common/ops/populate/squeeze_populate.cc index 8767f130f58c10a32664a7f73b9572c0a60bdf52..d4aad57048399a1d52e80b72af3a8cb100617932 100644 --- a/mindspore-lite/src/common/ops/populate/squeeze_populate.cc +++ b/mindspore-lite/src/common/ops/populate/squeeze_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/squeeze_parameter.h" +#include "nnacl_c/squeeze_parameter.h" using mindspore::schema::PrimitiveType_Squeeze; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/stack_populate.cc b/mindspore-lite/src/common/ops/populate/stack_populate.cc index 57bb5652bf76b180b808ea77690e762fcef91e8e..0094a8e28eacf744e24906f41cda4f4a31bba866 100644 --- a/mindspore-lite/src/common/ops/populate/stack_populate.cc +++ b/mindspore-lite/src/common/ops/populate/stack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" using mindspore::schema::PrimitiveType_Stack; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc b/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc index 88e82ea1bc3f1c927513ba0bc337824e97a5a915..09feb95c67a04791794a8f8f80d2230e75fb94e9 100644 --- a/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc +++ b/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" using mindspore::schema::PrimitiveType_StridedSliceGrad; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/strided_slice_populate.h b/mindspore-lite/src/common/ops/populate/strided_slice_populate.h index 552cc40bd0b81f4d939babc8689704bfc6464c3d..8a3cc5056c24d1a75b2d4f84819b5160f5a817ee 100644 --- a/mindspore-lite/src/common/ops/populate/strided_slice_populate.h +++ b/mindspore-lite/src/common/ops/populate/strided_slice_populate.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_STRIDED_SLICE_POPULATE_H_ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc b/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc index 536b2cc2c482925805453bfe2657ce37389c6e22..2992b1af72af8d5f4bf59105abb33383036428e6 100644 --- a/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc +++ b/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/predict_parameter.h" +#include "nnacl_c/predict_parameter.h" using mindspore::schema::PrimitiveType_CustomPredict; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc b/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc index 05a1517ea38148263d154ca73c82bd5d2cf3dd64..750027b6fc95f19e6d029740c98b71346714cfcc 100644 --- a/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc +++ b/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/lsh_projection_parameter.h" +#include "nnacl_c/lsh_projection_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_LshProjection; diff --git a/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc b/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc index 3baf95370faa81224af897a92c2b220b4df294a8..6b02efaa13e80d2cfbf1506603b49918ea7b814f 100644 --- a/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc +++ b/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" using mindspore::schema::PrimitiveType_SkipGram; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/sub_populate.cc b/mindspore-lite/src/common/ops/populate/sub_populate.cc index e9023698d23440b18cc0e5f87b947692630a8669..be1f3a99a8314a02e63b06639d0ef0722e0e164f 100644 --- a/mindspore-lite/src/common/ops/populate/sub_populate.cc +++ b/mindspore-lite/src/common/ops/populate/sub_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" #include "src/common/ops/populate/arithmetic_populate.h" using mindspore::schema::PrimitiveType_SubFusion; diff --git a/mindspore-lite/src/common/ops/populate/tile_populate.cc b/mindspore-lite/src/common/ops/populate/tile_populate.cc index b1b555d878a69232229f81aaf3e4b0028b7d75ec..faed13300f9d950481309b18745bb8c944543337 100644 --- a/mindspore-lite/src/common/ops/populate/tile_populate.cc +++ b/mindspore-lite/src/common/ops/populate/tile_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/base/tile_base.h" +#include "nnacl_c/base/tile_base.h" using mindspore::schema::PrimitiveType_TileFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/topk_populate.cc b/mindspore-lite/src/common/ops/populate/topk_populate.cc index a92f527718c192b1fb3d1e2c109bf7e07028f982..66446101e4fc409e8a88e20b0d2ec24c103bea2d 100644 --- a/mindspore-lite/src/common/ops/populate/topk_populate.cc +++ b/mindspore-lite/src/common/ops/populate/topk_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" using mindspore::schema::PrimitiveType_TopKFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/transpose_populate.cc b/mindspore-lite/src/common/ops/populate/transpose_populate.cc index a13950b78cfc925f8d8587b36e70efa4f2910a0a..11e1c3e3412c1bd6106480b0dcc35c87f004260a 100644 --- a/mindspore-lite/src/common/ops/populate/transpose_populate.cc +++ b/mindspore-lite/src/common/ops/populate/transpose_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" using mindspore::schema::PrimitiveType_Transpose; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc b/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc index dcb6b5120f5cb80ce0ed289108a1748e9dc4f865..2bc02a859ef2fffbfc9fe319a6cc76eaf551d1d3 100644 --- a/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc +++ b/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/triu_tril_parameter.h" +#include "nnacl_c/triu_tril_parameter.h" using mindspore::schema::PrimitiveType_Tril; using mindspore::schema::PrimitiveType_Triu; diff --git a/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc b/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc index d901852aedbe2ba7b46bf29c95eebea39dba9b1f..a00aa15cee4ecd8083d04fc7cbe0f6615a8512b5 100644 --- a/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc +++ b/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/populate/populate_register.h" #include "src/common/ops/populate/default_populate.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" using mindspore::schema::PrimitiveType_UniformReal; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/unique_populate.cc b/mindspore-lite/src/common/ops/populate/unique_populate.cc index 456f1c7910fb7ae614b0d0313b9bcaef8502acec..ebe429173b4807c1d506f0af8d5f50e9e045a986 100644 --- a/mindspore-lite/src/common/ops/populate/unique_populate.cc +++ b/mindspore-lite/src/common/ops/populate/unique_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/unique_fp32.h" +#include "nnacl_c/fp32/unique_fp32.h" using mindspore::schema::PrimitiveType_Unique; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc b/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc index 5feafc4cfb5222aa7567eaeb0cecbecb6df8ee4f..d859556c6eba8e199189f0964dd887304c7d9911 100644 --- a/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc +++ b/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/unsqueeze_parameter.h" +#include "nnacl_c/unsqueeze_parameter.h" using mindspore::schema::PrimitiveType_Unsqueeze; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/unstack_populate.cc b/mindspore-lite/src/common/ops/populate/unstack_populate.cc index 6602b08c347f3e61aa0a8b780e55b4b2acfafe67..a77ded6c506785df0c38779cdc909332409f0fa3 100644 --- a/mindspore-lite/src/common/ops/populate/unstack_populate.cc +++ b/mindspore-lite/src/common/ops/populate/unstack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/unstack_parameter.h" +#include "nnacl_c/unstack_parameter.h" using mindspore::schema::PrimitiveType_Unstack; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/where_populate.cc b/mindspore-lite/src/common/ops/populate/where_populate.cc index 489520384089217697078324a2b243854cb6e8ea..f3a32b254c581928ecc0e0c9a86fef73c39a2e33 100644 --- a/mindspore-lite/src/common/ops/populate/where_populate.cc +++ b/mindspore-lite/src/common/ops/populate/where_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/where_parameter.h" +#include "nnacl_c/where_parameter.h" using mindspore::schema::PrimitiveType_Where; namespace mindspore { diff --git a/mindspore-lite/src/common/prim_util.cc b/mindspore-lite/src/common/prim_util.cc index 7da14e5b3f239e3b72028569bd97c2edfc04bbf7..d640815d9f52d94d4087f3d7b415b8c1d63e01bb 100644 --- a/mindspore-lite/src/common/prim_util.cc +++ b/mindspore-lite/src/common/prim_util.cc @@ -17,7 +17,7 @@ #include "src/common/prim_util.h" #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "schema/model_generated.h" #include "src/common/log_adapter.h" diff --git a/mindspore-lite/src/common/tensor_util.cc b/mindspore-lite/src/common/tensor_util.cc index aecc0236f5d482cf3ae4c9ee1ce1b51b95578453..f6a33f1f073dc28cea5d6b525a67060966c18b93 100644 --- a/mindspore-lite/src/common/tensor_util.cc +++ b/mindspore-lite/src/common/tensor_util.cc @@ -23,7 +23,7 @@ #include "src/common/log_adapter.h" #include "src/litert/pack_weight_manager.h" #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore { namespace lite { void FreeInTensorC(std::vector *tensors_in, const std::shared_ptr &allocator) { diff --git a/mindspore-lite/src/common/tensor_util.h b/mindspore-lite/src/common/tensor_util.h index bd75f74b0b2096f4994f92c30aa99544cbeadb81..080bb448af952bb21a8a22722ff9db579d83ca07 100644 --- a/mindspore-lite/src/common/tensor_util.h +++ b/mindspore-lite/src/common/tensor_util.h @@ -20,11 +20,11 @@ #include #include #include "src/tensor.h" -#include "nnacl/tensor_c.h" -#include "nnacl/tensor_c_utils.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/tensor_c_utils.h" #include "src/tensorlist.h" -#include "nnacl/infer/common_infer.h" -#include "nnacl/tensorlist_c_utils.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/tensorlist_c_utils.h" #include "src/litert/cxx_api/tensor/tensor_impl.h" #include "include/api/visible.h" diff --git a/mindspore-lite/src/control_flow/control_flow_scheduler.cc b/mindspore-lite/src/control_flow/control_flow_scheduler.cc index a2a573de25a10f92d58180aecaf31326a79889dc..bfa80d916fbeca96e35653312a7bc15790819b6b 100644 --- 
a/mindspore-lite/src/control_flow/control_flow_scheduler.cc +++ b/mindspore-lite/src/control_flow/control_flow_scheduler.cc @@ -20,7 +20,7 @@ #include #include "src/litert/kernel_exec_util.h" #include "src/litert/kernel/cpu/base/partial_fusion.h" -#include "nnacl/call_parameter.h" +#include "nnacl_c/call_parameter.h" #include "src/control_flow/kernel/exit_subgraph_kernel.h" #include "src/control_flow/kernel/identity_kernel.h" #include "src/tensorlist.h" diff --git a/mindspore-lite/src/control_flow/control_flow_scheduler.h b/mindspore-lite/src/control_flow/control_flow_scheduler.h index 2344f3806ea544d9fbb7dbdac653f7eb28b30d7e..910a6aa1ee160276a6c4cd45afc75e32d323a39a 100644 --- a/mindspore-lite/src/control_flow/control_flow_scheduler.h +++ b/mindspore-lite/src/control_flow/control_flow_scheduler.h @@ -26,7 +26,7 @@ #include #include "src/common/utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/inner_context.h" #include "src/tensor.h" #include "src/executor/sub_graph_kernel.h" diff --git a/mindspore-lite/src/executor/kernel_exec.h b/mindspore-lite/src/executor/kernel_exec.h index 89a32b5f8cb284ba7b33a817b23cc525b8e2a8ba..37eacf3ed7c40588daa8948aea9c3b4b4c7ea363 100644 --- a/mindspore-lite/src/executor/kernel_exec.h +++ b/mindspore-lite/src/executor/kernel_exec.h @@ -26,7 +26,7 @@ #ifdef ENABLE_ARM #include #endif -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/inner_context.h" #include "src/tensor.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/executor/sub_graph_kernel.h b/mindspore-lite/src/executor/sub_graph_kernel.h index 13f1ef191c510199828fdc8f5ebaa6d6e4c66fcc..8f84a021eba793c2e1423d2f85a64752e56e15d5 100644 --- a/mindspore-lite/src/executor/sub_graph_kernel.h +++ b/mindspore-lite/src/executor/sub_graph_kernel.h @@ -29,7 +29,7 @@ #include "src/common/version_manager.h" #include "src/litert/cpu_info.h" #if defined(ENABLE_ARM) && defined(ENABLE_FP16) -#include "nnacl/constant_of_shape_parameter.h" +#include "nnacl_c/constant_of_shape_parameter.h" #endif namespace mindspore::kernel { diff --git a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc index 520adeacbc23e6b953871fa1d72108755f2251e2..24effad566d7ab8972d6cb478e33635e23ad7ea7 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc @@ -24,7 +24,7 @@ #include "ops/primitive_c.h" #include "tools/optimizer/common/gllo_utils.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "extendrt/cxx_api/model/model_impl.h" #include "extendrt/cxx_api/dlutils.h" diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc index 72712ed0fe69191c21d8f9b9575c4b5cd067f091..4fcdace8b48b11175e0bc6be7257d4bfce75bdad 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc @@ -17,7 +17,7 @@ #include "src/extendrt/cxx_api/model_pool/runner_config.h" #include "src/common/log_adapter.h" #include "src/litert/cpu_info.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #ifdef CAPTURE_SIGNALS #include "src/extendrt/signal_handler.h" #endif diff --git 
a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc index c350c05b454b7ece69817b5dddd963f11cb0900a..7573c0dd060b3172469491981570ecceef08ddfa 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc @@ -18,7 +18,7 @@ #include #include #include -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/extendrt/cxx_api/model_pool/resource_manager.h" #include "src/common/log_adapter.h" #include "include/lite_types.h" diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc index d36025acdf50166afa8cd6cdcc95fc2e9ae5201e..abed1765c0b9a95c2474fc9358adbbddb7e82e84 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc @@ -18,7 +18,7 @@ #include "src/common/log_adapter.h" #include "src/extendrt/numa_adapter.h" #include "src/common/common.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { void ModelWorker::PrintWorkerInfo() { MS_LOG(ERROR) << "worker id: " << worker_config_->worker_id diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc index a14f96b14fc94a9ccf433a767fa59f847ff88c13..a22224fdae58a9ef1766ed26a70b2ac152b63ccc 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc @@ -22,7 +22,7 @@ #include "src/common/log_adapter.h" #include "src/common/utils.h" #include "src/extendrt/numa_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace { constexpr int kNumIndex = 2; diff --git a/mindspore-lite/src/extendrt/delegate/delegate_utils.cc b/mindspore-lite/src/extendrt/delegate/delegate_utils.cc index 25ad349ce441b76813eb0d4ee150349930d4a25c..59081be964ef2133032fc9be6f678bcfb17aa3c7 100644 --- a/mindspore-lite/src/extendrt/delegate/delegate_utils.cc +++ b/mindspore-lite/src/extendrt/delegate/delegate_utils.cc @@ -15,7 +15,7 @@ */ #include "src/extendrt/delegate/delegate_utils.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::lite { bool IsSubGraphInputTensor(const std::vector<TensorInfo> &inputs, const TensorInfo &input) { return std::find(inputs.begin(), inputs.end(), input) != inputs.end(); } diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt index 452b434288339734c3342f27cdbb8f5200935b41..cdb9661dc6a26ab7b42f2337278f5aaff72b39de 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt @@ -12,7 +12,7 @@ set(TENSORRT_PATH $ENV{TENSORRT_PATH}) set(TENSORRT_LIB_PATH ${TENSORRT_PATH}/lib) include_directories(${TENSORRT_PATH}/include) -include_directories(${OPS_DIR}/kernel/cpu) +include_directories(${NNACL_DIR}/../) include_directories(${CCSRC_DIR}/../) include_directories(${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops) @@ -58,7 +58,7 @@ file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false ${CMAKE_CURRENT_SOURCE_DIR}/cuda_impl/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../../extendrt/delegate/delegate_utils.cc ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/cuda_device_info.cc -
${OPS_DIR}/kernel/cpu/nnacl/nnacl_common.c + ${NNACL_DIR}/nnacl_common.c ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc ) diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu index e5150eda5c8d97771a6d6ad239b7b318fc8342e0..f1ee734bd6cc312423ad449091b7c957ee2eaecb 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu @@ -17,7 +17,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" __device__ __forceinline__ uint64_t Pop(const uint64_t *chunks, uint64_t *curr_chunk, uint8_t bit_count, int32_t *curr_bit_count, int32_t *curr_chunk_index) { diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc index 608c3f4e00bdf31e279b31a06f13b9e4de580f0c..d8ade39543b85fa023c5fa810bdbba134dae1eb8 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc @@ -16,7 +16,7 @@ #include "src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h" #include -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc index 412913d0e8c76040d9d15e1741f11ac8198e41a8..822335743c9357c46e3b188e105918068b62bd14 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc @@ -16,7 +16,7 @@ #include "src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.h" #include -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc index 5856f63de313d4e3bf8d9ce81bb184fde9fd4399..afddf56b2582892b28eb7c7fd81276fcf5c32da2 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc @@ -17,7 +17,7 @@ #include "src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.h" #include #include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc index 3d82f8d1abbaf64e42c6003d869a5f8092d8bca3..b9cddd0aa1f946ea981a2f8bccc3404f11441f5f 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #include "infer/resize.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git 
a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc index eebb15529bb23c1772f57b168475bc9ba74682be..aa99d260a0cff87935591130cfefa921c407ec88 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc @@ -17,7 +17,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/common/utils/anfalgo.h" #include "mindspore/ccsrc/include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h index b5787b4e90567f504f18ad6795a34a6d1c1521fd..cee6582fc02b66e82a7ac8d82c8f351241ff9660 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h @@ -27,7 +27,7 @@ #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h" #include "ir/dtype/type_id.h" #include "schema/ops_generated.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "include/api/context.h" #include "mindapi/base/types.h" diff --git a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc b/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc index 67302c5d75ceb8ad3ac992ba953898a813c4061d..e8527d0dba2a11fc6f6131a987177306a8890b64 100644 --- a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc +++ b/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc @@ -26,7 +26,7 @@ #include "src/common/draw/drawer.h" #include "src/extendrt/kernel/nnacl/nnacl_base_kernel.h" #include "src/extendrt/kernel/extendrt_kernel_exec.h" -#include "nnacl/format_transpose_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" #include "extendrt/delegate/ascend_native/delegate.h" #include "extendrt/delegate/factory.h" diff --git a/mindspore-lite/src/extendrt/infer_session.cc b/mindspore-lite/src/extendrt/infer_session.cc index 3648c0121f86310a6139fd7b434433a38527243d..2e74d185b510f98803f71d5dc7d4492612a513fb 100644 --- a/mindspore-lite/src/extendrt/infer_session.cc +++ b/mindspore-lite/src/extendrt/infer_session.cc @@ -23,7 +23,7 @@ #include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h" #include "extendrt/delegate/plugin/ascend_native_executor_plugin.h" #include "extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace { diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc b/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc index d3f531f67484d46086b332791631d5762898716c..f3b5969518557fdd8887d26dd625c127c9581546 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc +++ b/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc @@ -16,7 +16,7 @@ #include "extendrt/kernel/ascend/model/dyn_shape_process.h" #include -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc 
b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc index dfef81ad04573380ec99b2a84fe17d6d26a49b33..aafd9a4278f30728662c7cc3c2bfd5062640949f 100644 --- a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc +++ b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc @@ -20,7 +20,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "common/ms_factory.h" #include "include/api/status.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h index 94dd8fa76dbfa2f114aa1d4697f5cc7d0e2f7c7b..c1b0b87474b346a6a43e9303f25ba5045e411a10 100644 --- a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h +++ b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h @@ -22,7 +22,7 @@ #include #include #include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "common/common_utils.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc b/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc index ee4f07aca0255911029c85e416ef2c25435e5176..e9ff4be9a7d46c9c63e5b8240ccff1536d9cf093 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc +++ b/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc @@ -16,7 +16,7 @@ #include "src/extendrt/kernel/cuda/batchtospace.h" #include -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" namespace mindspore::kernel { int BatchtoSpaceCudaKernel::Prepare() { diff --git a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc index 6e5e52d688f9caa215d8c0384568752689c99d1f..755bcece5c5ec3299af45ad1a6d849bf91eb502a 100644 --- a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc +++ b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc @@ -22,7 +22,7 @@ #include "ir/tensor.h" #include "ir/value.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h index 3fc1e267246fd37de2f386f3924a82d4d02f95e0..404acb0ed2b766225e2676cb404f13ec06a1543b 100644 --- a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h +++ b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_ #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_ -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore { ArithmeticParameter *PopulateArithmeticCommonPara(void *prim); diff --git a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h index 5373ecd70b8f0fa4768f5f6117f3d4547647debe..1a6e1251afe9d92b9d85c065a5c72d8d69d5bcb8 100644 --- 
a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h +++ b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h @@ -22,7 +22,7 @@ #include #include "schema/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "src/common/log_adapter.h" #include "src/common/prim_util.h" diff --git a/mindspore-lite/src/infer/primitive_type.cc b/mindspore-lite/src/infer/primitive_type.cc index 312e05621af83d3e652c0d00341a90763d436781..9eaaf430ed559dbcebfa7eb27ed3e533c665925b 100644 --- a/mindspore-lite/src/infer/primitive_type.cc +++ b/mindspore-lite/src/infer/primitive_type.cc @@ -15,7 +15,7 @@ */ #include "src/infer/primitive_type.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::kernel { #ifdef ENABLE_CLOUD_INFERENCE diff --git a/mindspore-lite/src/litert/cpu_info.cc b/mindspore-lite/src/litert/cpu_info.cc index 7510de1ae53b2e329b547c183c9240a20adb8ff7..1302fe02bb13cbc5e5d77ca82a17bd306d766333 100644 --- a/mindspore-lite/src/litert/cpu_info.cc +++ b/mindspore-lite/src/litert/cpu_info.cc @@ -18,7 +18,7 @@ #include #include #include "src/common/log_adapter.h" -#include "nnacl/nnacl_utils.h" +#include "nnacl_c/nnacl_utils.h" #if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(MS_COMPILE_IOS) #include #include diff --git a/mindspore-lite/src/litert/cpu_info.h b/mindspore-lite/src/litert/cpu_info.h index d6ac9f767ffa926804695c0e8b3c3f9b7f041cfd..48f51c5047af6893a579b69cc93f412d88e1c20d 100644 --- a/mindspore-lite/src/litert/cpu_info.h +++ b/mindspore-lite/src/litert/cpu_info.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_CPU_INFO_H_ #if defined(ENABLE_AVX512) || defined(ENABLE_AVX) -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #endif inline bool PlatformInstructionSetSupportCheck() { diff --git a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc index 3ea67a2ec021622944a2c625205254b0ac76fd14..20dfd8d42053596d6253064f7e9ed4b42f01f69e 100644 --- a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc +++ b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc @@ -15,7 +15,7 @@ */ #include "src/litert/delegate/coreml/op/coreml_op.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore::lite { int CoreMLOp::Init() { auto ret = InitParams(); diff --git a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h index d23dddb96a3a3aff3e4ffc80c8ed6091b3d6516f..31d6982bef3680b7b6ba8ef1ea3546f1a0f0a01e 100644 --- a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h +++ b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h @@ -30,7 +30,7 @@ #include "include/api/data_type.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_NOT_SUPPORT; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/delegate/delegate_utils.cc b/mindspore-lite/src/litert/delegate/delegate_utils.cc index 41d7ea938814c3215f4bb5a468c60590c061f9c2..c9aeeb11ca562ff2aa9bb38cef25a0f015c2ef1f 100644 --- a/mindspore-lite/src/litert/delegate/delegate_utils.cc +++ b/mindspore-lite/src/litert/delegate/delegate_utils.cc @@ -15,7 +15,7 @@ */ #include "src/litert/delegate/delegate_utils.h" 
-#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::lite { void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { int hw8 = plane / C8NUM * C8NUM; diff --git a/mindspore-lite/src/litert/delegate/delegate_utils.h b/mindspore-lite/src/litert/delegate/delegate_utils.h index 7aaa9938ba569a2c94544592032d39d3d13a0d56..5843699cf6e929e55ecae2e9804c84a55fc4fa61 100644 --- a/mindspore-lite/src/litert/delegate/delegate_utils.h +++ b/mindspore-lite/src/litert/delegate/delegate_utils.h @@ -19,7 +19,7 @@ #include "include/api/delegate.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { bool IsSubGraphInputTensor(const std::vector &inputs, mindspore::MSTensor input); diff --git a/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt b/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt index 5383d643b956e02ae268dcc715e675daa4615e78..9d94e0e4b3dfed2d2864a176c8931d2288d245d5 100644 --- a/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt +++ b/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(${DDK_PATH}) -include_directories(${OPS_DIR}/kernel/cpu) +include_directories(${NNACL_DIR}/../) file(GLOB_RECURSE NPU_RUNTIME_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/op/*.cc diff --git a/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h b/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h index 6ffa3975ea89642a2178306d148fb5286ce01ec0..d76c6f2f87ed4a701394914feb6fbf0fd8a84b49 100644 --- a/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h +++ b/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h @@ -29,7 +29,7 @@ #include "include/api/data_type.h" #include "include/graph/op/all_ops.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { enum NCHW_SHAPE { NCHW_INVALID = -1, NCHW_N = 0, NCHW_C = 1, NCHW_H = 2, NCHW_W = 3 }; diff --git a/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc b/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc index 1b344a90a9710f982cc8a710238a04559891330a..0b2e176c388369576b868cae2db2dafe639a28c4 100644 --- a/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc +++ b/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc @@ -18,7 +18,7 @@ #include "src/litert/delegate/npu/npu_converter_utils.h" #include "src/litert/delegate/npu/transpose_kernel.h" #include "src/litert/delegate/delegate_utils.h" -#include "nnacl/int8/pack_int8.h" +#include "nnacl_c/int8/pack_int8.h" namespace mindspore::lite { ConvolutionBaseNPUOp::~ConvolutionBaseNPUOp() { diff --git a/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc b/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc index 02bf60fa7e33860a535ace9feb579f77cc1af9cb..1c171de018bb5d44337c5d3be019c4cd78fbb868 100644 --- a/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc +++ b/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc @@ -16,7 +16,7 @@ #include "src/litert/delegate/npu/op/deconvolution_npu.h" #include "src/litert/delegate/npu/npu_converter_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/litert/delegate/npu/op/npu_op.h b/mindspore-lite/src/litert/delegate/npu/op/npu_op.h index 
9628bf8652892f56da24d4ef577971a9a38dc4ad..215d445c9fc67467afc69ba4c63d6f8ae58cebb3 100644 --- a/mindspore-lite/src/litert/delegate/npu/op/npu_op.h +++ b/mindspore-lite/src/litert/delegate/npu/op/npu_op.h @@ -28,7 +28,7 @@ #include "include/api/data_type.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_NOT_SUPPORT; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc b/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc index 36e9409cbec9ffdf169cca6562204ceb6500850e..38c826aa9d9f1484ca551a57b6fc80cad48db62f 100644 --- a/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc +++ b/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc @@ -18,7 +18,7 @@ #include "src/litert/delegate/npu/npu_converter_utils.h" #include "src/litert/delegate/npu/op/npu_op.h" #include "src/litert/delegate/delegate_utils.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::lite { int TransposeNPUKernel::Execute() { if (perm_ != NHWC2NCHW_PERM && perm_ != NCHW2NHWC_PERM) { diff --git a/mindspore-lite/src/litert/infer_manager.cc b/mindspore-lite/src/litert/infer_manager.cc index d8a240b4171ba779c11ca7581ba64e0fd7d21998..6d7e7c2038f5c4d011068a6650eac2556c8c369c 100644 --- a/mindspore-lite/src/litert/infer_manager.cc +++ b/mindspore-lite/src/litert/infer_manager.cc @@ -23,7 +23,7 @@ #include "src/litert/cxx_api/tensor/tensor_impl.h" #include "schema/model_generated.h" #include "include/errorcode.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" #include "src/tensorlist.h" #include "include/registry/register_kernel_interface.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/infer_manager.h b/mindspore-lite/src/litert/infer_manager.h index 39465bfea5ebedfffcdf612d4dbc6c96e9d0f344..9a8766cc36a53aa03d226e78b4b2eef4d09541d5 100644 --- a/mindspore-lite/src/litert/infer_manager.h +++ b/mindspore-lite/src/litert/infer_manager.h @@ -24,8 +24,8 @@ #include #include "src/common/prim_util.h" #include "src/tensor.h" -#include "nnacl/tensor_c.h" -#include "nnacl/infer/infer.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/infer/infer.h" #include "include/api/kernel.h" #include "include/api/allocator.h" diff --git a/mindspore-lite/src/litert/inner_context.h b/mindspore-lite/src/litert/inner_context.h index 88281eb1f29158569ff253375cc6fba1fcdaa194..e5f02fb4fc1c28e697b67ff97210ed05feb21745 100644 --- a/mindspore-lite/src/litert/inner_context.h +++ b/mindspore-lite/src/litert/inner_context.h @@ -27,8 +27,8 @@ #include "src/litert/inner_allocator.h" #endif #include "thread/threadpool.h" -#include "nnacl/op_base.h" -#include "nnacl/kernel.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" #ifdef ENABLE_ARM #include "src/litert/cpu_info.h" #endif diff --git a/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc b/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc index 3954ed4235632112374fdd6b68016d20610561de..e46cc6c9bbc68bb0df41968f56d50492fb39b473 100644 --- a/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc +++ b/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc @@ -19,7 +19,7 @@ #include "include/errorcode.h" #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace 
mindspore::kernel { namespace acl { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc index 4455901e78e3aa00be870ac8980fcf2a1ade6bfe..37541ed7f699d28900b510d38dd153595b7217a1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc @@ -18,7 +18,7 @@ #include #include #include -#include "nnacl/base/arithmetic_base.h" +#include "nnacl_c/base/arithmetic_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h index f436e8df2973bb1f910960cf16548013113d4c3a..3e3ad63dc3005ea57419ab62a76152de5f2d2399 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore::kernel { class ArithmeticBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h b/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h index bc706e4f71df301b77c2d7bf4d967d637d20155e..b2032cab5a1c30d13a44c68df7aef965b52272e6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h @@ -19,9 +19,9 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/constant_of_shape_parameter.h" -#include "nnacl/fp32/constant_of_shape_fp32.h" -#include "nnacl/fp16/constant_of_shape_fp16.h" +#include "nnacl_c/constant_of_shape_parameter.h" +#include "nnacl_c/fp32/constant_of_shape_fp32.h" +#include "nnacl_c/fp16/constant_of_shape_fp16.h" namespace mindspore::kernel { class ConstantOfShapeCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc b/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc index edffea42670d722cef0a1166b546870114fc25d6..44ac568e4d0e3e5811c85f6e683e0f6c9e09187d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc @@ -17,7 +17,7 @@ #include "include/errorcode.h" #include "src/litert/kernel/cpu/base/custom_is_inf.h" #include "src/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc b/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc index 85cfeab6b936f6eebef22f357d4b21bed52e0c44..cb384f0ad62af73c42f848c0174597a5ff5b9d3d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc @@ -17,7 +17,7 @@ #include "include/errorcode.h" #include "src/litert/kernel/cpu/base/custom_masked_fill.h" #include "src/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc b/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc index e118e8c123bc6f70ae8e2dc49c9f38b9f71c5f13..be974c847c3b933b4b3ad08fff253be49dc658fa 
100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/base/scatter_nd_binary.h" +#include "nnacl_c/base/scatter_nd_binary.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc index 46477ac9320fe0f3dfdf741177afc6cb15379d68..a13c9c9067f313a84e4ee7e9b21394a76d903a0e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc @@ -20,7 +20,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h index 3f4798272152a253f7ccfc9aaf292bb16aafa9bd..41e0458b1034e6fcc20657ab20bf549ec7b1f57f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/detection_post_process_fp32.h" +#include "nnacl_c/fp32/detection_post_process_fp32.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc index 56ddd255cd2e47cf2d79117148a0b44bb5ba22a8..7f77e9780e57159d96e370c706ad4e1061b3fa3d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/base/format_transpose.h" -#include "nnacl/base/format_transpose.h" +#include "nnacl_c/base/format_transpose.h" #include "src/litert/kernel_registry.h" using mindspore::kernel::KERNEL_ARCH; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h index 062bee69459be8bcc72ca898bdf9574060384195..1468808d934b284ea88a0ea3762c0e4c792b028e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/format_transpose_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" namespace mindspore::kernel { class FormatTransposeCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h index f4d34d385cac196b8cb48963e8bcf3e02bf4f3ea..e11b85dd9f3b256a52b83e6db42d8fb36921f4b9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h @@ -20,9 +20,9 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include 
"src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" #include "src/litert/kernel/cpu/base/group_convolution_creator.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h index 0afaa11bab4688d9a3fde085457701a66cc21a7d..c362a7ef1e4bd4f691a734d751bcdd5d079089f8 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/litert/tensor_category.h" #include "include/api/allocator.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h b/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h index ee1b4c2dc1c7ab6c69c5ff1c024d25986629a334..db0fb59fbfcfb53a97e381191d3de766f8ee1679 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h @@ -20,7 +20,7 @@ #ifdef ENABLE_FP16 #include #endif -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "src/tensor.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc b/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc index ad2898e001586377897fb472462459501dbac5d5..812f6dc34de441f59bd41dcb288e0ca783a8373a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/base/quant_dtype_cast.h" #include -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h b/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h index b352386bbd5fdd1e1016bb63870c19507ac6736f..91ba1db05d53d4437b6d083f6ce9d53a8861a713 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h b/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h index b38421c0a95f81b56c75e318645ec86fd01a708e..963e808571a7045cd89633105b05ec6226af3a8c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::kernel { class ReduceBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h b/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h index c82068bcf499b627e2be18991b781b503ab76c5c..1f94416e52cb57c916254ddc56d0f7f641d8b969 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include 
"nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" using mindspore::schema::PrimitiveType_Resize; using mindspore::schema::ResizeMethod; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h index 6376f26e743ec0ff483a8a690fd5af71d355fe97..043481f971c11add79ac5df3c72e5e62df26ee05 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/scatter_nd_binary.h" +#include "nnacl_c/base/scatter_nd_binary.h" namespace mindspore::kernel { class ScatterNDCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h index 36db5771ca6dea06b6a910d361a51894ee6b76eb..88972f13a022bab3cbce180f8ab82f2ef2165d1e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/scatter_nd_binary.h" +#include "nnacl_c/base/scatter_nd_binary.h" namespace mindspore::kernel { constexpr int kScatterUpdateInputIndex = 0; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/split_base.h b/mindspore-lite/src/litert/kernel/cpu/base/split_base.h index 582905e36a3ac0fa65d1642246bbf6f27f2b18ae..8ec8ce5c3cdaa5e6729fe317b44a469dbcb912cf 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/split_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/split_base.h @@ -20,8 +20,8 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/split_parameter.h" -#include "nnacl/base/split_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/base/split_base.h" namespace mindspore::kernel { class SplitBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc index 88efe29b1b5529eead6c099ae633277bce7a5b17..77cf39ceeac99738de082d32360cdf4c3df78628 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc @@ -17,7 +17,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "src/tensor.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h index df3e84eafb26f795e66f09a06c6c6213eb4458c8..dc674980848d21a50a2d23ef30623725cd9da213 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h @@ -20,8 +20,8 @@ #include #include "include/errorcode.h" #include "src/executor/kernel_exec.h" -#include "nnacl/split_parameter.h" -#include "nnacl/base/split_with_over_lap_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/base/split_with_over_lap_base.h" namespace mindspore::kernel { class SplitWithOverlapBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h b/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h 
index 930da7dacf14621ee89bd3d4d6fd5e5696a05a22..d664ab2b0c8bdee765b208af0bd918b3cfbe3308 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_BASE_TRANSPOSE_BASE_H_ #include -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc index 28203bed63af3569499c5cb544f5e4322178d8af..233545ab8110a091e0cade80bd906953f1ee5053 100644 --- a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc +++ b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc @@ -16,7 +16,7 @@ #include "bolt/bolt_parameter_manager.h" #include "bolt/bolt_utils.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "schema/ops_generated.h" namespace mindspore::kernel::bolt { diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h index b0702b92c44d733d83298eca75538e8b3ba096cc..87720f64f89caa3beb71ec0215ac352d0a3cc0f6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h +++ b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h @@ -19,7 +19,7 @@ #include #include "bolt/common/uni/include/parameter_spec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_adapter.h" namespace mindspore::kernel::bolt { diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h index 4763820421ece3e9f3dbd59914896ac849e1272a..00da8cdd9bc856a0a900cead2f0fc7ccefd7beb9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h +++ b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_UTILS_H_ #include "bolt/common/memory/include/tensor.hpp" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "bolt/common/uni/include/parameter_spec.h" typedef Tensor BoltTensor; diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc b/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc index c7ea263916043adebd1f7b4e8b8eef2c744fc307..b7738db39e741c51a5f6869441623faf1b54b3fa 100644 --- a/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc +++ b/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc @@ -16,8 +16,8 @@ #include "bolt/convolution_bolt.h" #include "bolt/bolt_kernel_manager.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/pack.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/pack.h" #include "bolt/compute/tensor/include/tensor_computing.h" #include "bolt/common/memory/include/tensor_desc.h" #include "bolt/bolt_tensor_utils.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h b/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h index 3d691bec25633ead0e9dde407446f8e143f5b8e0..6413ccdb40f557e32f0802b77ca6dcd6c863f9f3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h +++ b/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/tensor_array_parameter.h" +#include "nnacl_c/tensor_array_parameter.h" #include 
"src/litert/lite_kernel.h" #include "src/tensorlist.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h index c4f8836a99bdcd98974c9c3b32c2d66442b1dab1..0cbf87fe40ea12ca033b68a8506a92ae7bcf6ce3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h +++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h @@ -21,7 +21,7 @@ #include "src/litert/lite_kernel.h" #include "src/tensorlist.h" #include "schema/model_generated.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" namespace mindspore::kernel { class TensorListFromTensorCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h index d2d81699b97d20dea9b1f007a29a5cdac626b207..374c82340596c5cd638443c87a89c4467aad0327 100644 --- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h +++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h @@ -21,7 +21,7 @@ #include "src/litert/lite_kernel.h" #include "src/tensorlist.h" #include "schema/model_generated.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" namespace mindspore::kernel { class TensorListGetItemCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h index 884e0498e4b6d308d0a412f3c9e57c3acddd5603..4fdb8d9299324ada4258f8ad0e8e92e6294f55ed 100644 --- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h +++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h @@ -21,7 +21,7 @@ #include "src/litert/lite_kernel.h" #include "src/tensorlist.h" #include "schema/model_generated.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" namespace mindspore::kernel { class TensorListReserveCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h index 032646bc74081a5633d1b2b70fb5610f5d7d029e..f3f516ba64c1eab2109d4551493a8bb8c527a3c3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h +++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h @@ -21,7 +21,7 @@ #include "src/litert/lite_kernel.h" #include "src/tensorlist.h" #include "schema/model_generated.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" namespace mindspore::kernel { class TensorListSetItemCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h index f97c09d8153a4fb5b9c5fbbc09e49a9414c5cd04..85d3f03f00cb1b42c594d46739d4a1516e947c82 100644 --- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h +++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h @@ -22,7 +22,7 @@ #include "src/litert/lite_kernel.h" #include "src/tensorlist.h" #include "schema/model_generated.h" -#include "nnacl/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_parameter.h" namespace mindspore::kernel { class TensorListStackCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h 
b/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h index 25d977fd3307f79d08f8a25cff27bf3f2078c9e9..c4b838cd6457ffb881ac86e7bce60d9c3bfa6495 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_BIASADD_FP16_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" namespace mindspore::kernel { class BiasAddCPUFp16Kernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h index c88a68e123e4746fc03b41bedae5cc1a60c2175d..447daff336d015c58e58f103eb7ca224ffa8914b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h @@ -18,9 +18,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore::kernel { class CastFp16CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc index 47ac4de3d0d5d3840ee80eaafba99b91054acf7f..1a1d5dc6ae9e331a9d9f459c3860000df827156e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/fp16/common_fp16.h" -#include "nnacl/fp16/cast_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc index 536b962cca7de4ebe8cf955a9ed5aab26cdbe517..493f7fa0c08165957b8858e671e7f1252e31c860 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc @@ -15,10 +15,10 @@ */ #include "src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h" -#include "nnacl/base/conv1x1_base.h" -#include "nnacl/fp16/conv_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/pack_fp16.h" +#include "nnacl_c/base/conv1x1_base.h" +#include "nnacl_c/fp16/conv_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" #include "src/litert/kernel/cpu/fp16/layout_transform_fp16.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h index 76cde5b90207c17448e9cf3eb2109ad1fa5c64dc..640e789ad261a9864aa527b64f37279937151993 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h @@ -22,8 +22,8 @@ #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" #include "src/common/utils.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/fp16/matmul_fp16.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/fp16/matmul_fp16.h" namespace mindspore::kernel { class Convolution1x1FP16CPUKernel : public ConvolutionBaseCPUKernel { diff --git 
a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc index 42288a210c114494eb266131ff256d19e5177a25..53d04f421f2fa7293830f34d08ed36b9d1f1c44c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc @@ -24,7 +24,7 @@ #include "src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h" #include "src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h" #include "src/litert/kernel/cpu/base/group_convolution_creator.h" -#include "nnacl/base/conv_common_base.h" +#include "nnacl_c/base/conv_common_base.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h index c1581f525995e323c68924df220272ac35cc547c..94c294d3b507ffce5c37b633e3450d913703f113 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h @@ -19,8 +19,8 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/op_base.h" #define WEIGHT_NEED_FREE 0001 #define BIAS_NEED_FREE 0010 diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc index 435560d74cc9972a5bdcd3eab63e664096df35e3..88e9f7202c59f214424d609796810ac58dc8c526 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc @@ -17,8 +17,8 @@ #ifdef ENABLE_ARM #include "src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h" #include "include/errorcode.h" -#include "nnacl/fp16/pack_fp16.h" -#include "nnacl/fp16/conv_depthwise_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/conv_depthwise_fp16.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h index 303837245d1c0088c700a21ea44ce916055d44cc..28a8714aa3c1b92c0d7c1b247ba6c3d4f7dfa74b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwise3x3Fp16CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc index 04312bb66e3f62a182b8e5eb07d8e520af3015e1..b38c694a780126943d2fc3e23430dbac56f70f08 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h" -#include 
"nnacl/fp16/pack_fp16.h" -#include "nnacl/fp16/cast_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h index 706faeecabcff2675264dcc23466b6594dd54379..bc8ded4ba674d5c63f32b1125d1d39452b8be64e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp16/conv_depthwise_fp16.h" +#include "nnacl_c/fp16/conv_depthwise_fp16.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc index bdeb2e574b6c3b87327ed08008f02020f965463f..50edd50dfe7621d8e94d58512c4100a105c53747 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h" -#include "nnacl/fp16/pack_fp16.h" -#include "nnacl/fp16/cast_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h index c944041e2a6f7bff103ccfd0065d8555300aebc1..a535bedf60d702e8898274ebe13d2ee5d179a1db 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp16/conv_depthwise_fp16.h" +#include "nnacl_c/fp16/conv_depthwise_fp16.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc index cebd9ea740f2e065c8a1340e49a9860fa9456264..a8d04577dfacac34334d4f094b72892c1e8ce6b9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc @@ -17,11 +17,11 @@ #include "src/litert/kernel/cpu/fp16/convolution_fp16.h" #include #include "include/errorcode.h" -#include "nnacl/fp16/conv_fp16.h" -#include "nnacl/fp16/matmul_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/pack_fp16.h" -#include "nnacl/fp16/winograd_utils_fp16.h" +#include "nnacl_c/fp16/conv_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/winograd_utils_fp16.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h index f15c93c06005b5ab4187e6ad72e19064b338c6b8..c12cf98085e8d48be7a268b044d90990f4213fb5 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h @@ -21,10 +21,10 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp16/conv_fp16.h" -#include "nnacl/fp16/winograd_utils_fp16.h" +#include "nnacl_c/fp16/conv_fp16.h" +#include "nnacl_c/fp16/winograd_utils_fp16.h" #include "src/common/utils.h" -#include "nnacl/base/minimal_filtering_generator.h" +#include "nnacl_c/base/minimal_filtering_generator.h" namespace mindspore::kernel { class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc index 1487188372c99fd726fb34a954f2671c70ee8aab..4dda8d7af463bade03041d6827253e548919f54c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc @@ -20,10 +20,10 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/litert/pack_weight_manager.h" -#include "nnacl/fp16/pack_fp16.h" -#include "nnacl/custom_gru_parameter.h" -#include "nnacl/fp16/custom_gru_fp16.h" -#include "nnacl/fp16/matmul_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/custom_gru_parameter.h" +#include "nnacl_c/fp16/custom_gru_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc index b42ccdc216c38523d4988fc6b9b171288cbbf21b..9059260bafd172703187473480a8eedf0ab2f224 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h" -#include "nnacl/fp16/pack_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h index 585572209cc8939e7d2a0d6ca5247baf220138f2..ea28cb1c483840feb5ebf3ea95b4dadf3e62b3b2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp16/conv_depthwise_fp16.h" +#include "nnacl_c/fp16/conv_depthwise_fp16.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h index 4a39cbe4621ced27e71e7ecabd52a58016e13076..2966679775e7f6253fb97bf138365d26c8e2478b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h @@ -18,8 +18,8 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_DECONVOLUTION_FP16_H_ #include -#include "nnacl/fp16/deconv_fp16.h" -#include "nnacl/fp16/matmul_fp16.h" +#include "nnacl_c/fp16/deconv_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" #include "src/litert/kernel_registry.h" #include 
"src/litert/kernel/cpu/base/convolution_base.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h index 12120d9c0206bd27fae033098f648029e5eb90dc..6e276d3f08dd79105afe646825e69af9aecfa015 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h @@ -19,9 +19,9 @@ #include #include "include/errorcode.h" -#include "nnacl/fp16/common_func_fp16.h" -#include "nnacl/fp16/deconv_winograd_fp16.h" -#include "nnacl/fp16/pack_fp16.h" +#include "nnacl_c/fp16/common_func_fp16.h" +#include "nnacl_c/fp16/deconv_winograd_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc index 5857eeae3a5ca28e3e18ef37006dba18a477f48a..0722c9d6747017cb56b698f5c91ae26f3ac78fb4 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc @@ -19,9 +19,9 @@ #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" #include "include/errorcode.h" -#include "nnacl/dynamic_quant_parameter.h" -#include "nnacl/fp16/dynamic_quant_fp16.h" -#include "nnacl/fp16/quant_dtype_cast_fp16.h" +#include "nnacl_c/dynamic_quant_parameter.h" +#include "nnacl_c/fp16/dynamic_quant_fp16.h" +#include "nnacl_c/fp16/quant_dtype_cast_fp16.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h b/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h index b7656eebcf80af2a636b35f5136b1db54e112f07..7447359def21aeb6c620f4c08a9a769a6df3a6da 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h @@ -16,10 +16,10 @@ #ifdef ENABLE_ARM #include #ifdef ENABLE_FP16 -#include "nnacl/fp16/cast_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" #endif #endif -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h index ca6d09f46525cc82a96482bfa183bb4cbef46a25..1441abc419362424c4c04aa87178fa6a0334f222 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h @@ -20,9 +20,9 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/group_convolution_base.h" -#include "nnacl/fp16/conv_fp16.h" +#include "nnacl_c/fp16/conv_fp16.h" namespace mindspore::kernel { class GroupConvolutionFP16CPUKernel : public GroupConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc index c971a33ba42b57d97ed5ff8f9675a2f0474cc63e..5bdbdd9fd041b2600e74ce02d1180926a27f2f2a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc @@ -18,10 +18,10 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include 
"include/errorcode.h" -#include "nnacl/fp16/gru_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/lstm_fp16.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/fp16/gru_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/lstm_fp16.h" +#include "nnacl_c/errorcode.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h index baf5191a652de3a536670f65371ee728dd441bbf..c5bf5dacc89d054300be2344060db4f57e97e213 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_GRU_FP16_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/gru_parameter.h" +#include "nnacl_c/gru_parameter.h" namespace mindspore::kernel { class GruFp16CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc index 0c599fdfe59488cb7a2dc5aec9b64aaa299c2dc3..9dc1d9ca49a862c24d1297ee806154308bdc362e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc @@ -17,9 +17,9 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/instance_norm_fp16.h" -#include "nnacl/fp16/pack_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/instance_norm_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h index da58e101d784d404fee0546743d341a0e4c39117..31009de6012117ad2c1fa76976a0bfb31f5e2534 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_INSTANCE_NORM_FP16_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/instance_norm_parameter.h" +#include "nnacl_c/instance_norm_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc index 73ec20258f0011e4cd9bda1ddbce8e687e80a463..361a0870a4148e6ec8fca55be1ee1ec24e135058 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/litert/kernel/cpu/fp16/layout_transform_fp16.h" -#include "nnacl/fp16/pack_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" #include "src/common/log_adapter.h" #include "schema/ops_types_generated.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc index e358e58b5648f9e6858a0e1195a3f992bade3075..2d82ff3248243247fdb20e5f4280a51ea1aec2b1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc @@ -16,8 +16,8 @@ #include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" #include -#include "nnacl/fp16/lstm_fp16.h" -#include "nnacl/fp16/cast_fp16.h" +#include "nnacl_c/fp16/lstm_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h index a5c155485ab187a23b9995257715b76988556747..f68ac600f318a8e1d5cb2c9de4d54342124e7d04 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/lstm_parameter.h" +#include "nnacl_c/lstm_parameter.h" namespace mindspore::kernel { class LstmFp16BaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc index b8100db07ec206f5fd8732a3d13c99517f027062..4977df7be6f8a432efe916196bb6f835c084f094 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h" -#include "nnacl/fp16/lstm_fp16.h" +#include "nnacl_c/fp16/lstm_fp16.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc index c6adc6aa68a69879af35413e1d716288b0cab074..cf7a32e42c7ab5ba92f03db6efff017b94e3d7de 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h" -#include "nnacl/fp16/lstm_fp16.h" +#include "nnacl_c/fp16/lstm_fp16.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc index 982daeefc94f36b458b7b15a5506669e5da59aa8..9f01f0f51edb3dc9c4984818033d39f8dbc7bb55 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc @@ -16,8 +16,8 @@ #include "src/litert/kernel/cpu/fp16/matmul_base_fp16.h" #include -#include "nnacl/fp16/matmul_fp16.h" -#include "nnacl/fp16/cast_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" #include "include/errorcode.h" using mindspore::lite::kCHWDimNumber; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h index fd7a27d67a57837f8dd042b5c321923b3a34fdcf..3a467f3401ecd539cc3770a81f8b72e040a37e5e 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h @@ -23,7 +23,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/common/common.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { class MatmulBaseFP16CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc index a8428c39a796b76799bede33469cd1e841c06f69..4841a3bb3bd65c1a89ff823ea7f2615c9b032f9c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.h" #include -#include "nnacl/int8/quant_dtype_cast_int8.h" -#include "nnacl/fp16/quant_dtype_cast_fp16.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/fp16/quant_dtype_cast_fp16.h" #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h index cdae563fb67370346eb28f6445476c89f854d2bb..faf0944b869ade12988a50bb00ea3982acb0a044 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h @@ -19,7 +19,7 @@ #include #include #include "src/litert/kernel/cpu/fp32/resize_fp32.h" -#include "nnacl/fp16/resize_fp16.h" +#include "nnacl_c/fp16/resize_fp16.h" namespace mindspore::kernel { class ResizeFp16CPUKernel : public ResizeCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc index 8d54bb34ca649f950e0afa720daf8a7a1d256979..f6d1a1d2230c26f1585d6f8e26c495895f1bdc1a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h" -#include "nnacl/fp16_grad/activation_grad_fp16.h" +#include "nnacl_c/fp16_grad/activation_grad_fp16.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h index 0b81bd699e16df5fe85e204b580a4a0c60917ec2..ee88b288076b454e8d8d5ee411b4b7c6fd2e087b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp16_grad/activation_grad_fp16.h" -#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl_c/fp16_grad/activation_grad_fp16.h" +#include "nnacl_c/fp32_grad/activation_grad_fp32.h" namespace mindspore::kernel { class ActivationGradCPUKernelFp16 : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc index 15aa2fe45b6f368bc670649e23f490f7a4bf6dce..064eb5f38c111ef443bf13bf2524edfeb9695c8d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc 
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/arithmetic_grad.h" +#include "nnacl_c/fp16_grad/arithmetic_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h index 114a8e5d57b3002619588f86083db2cce85a52c1..20c45cebae762d738228c5a639dc7a524b39205c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" #include "schema/model_generated.h" using mindspore::schema::PrimitiveType_AddGrad; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h index 105db40eee3a2cd11e486fffe6ded29b23ed63f1..ac5e6e73746c0f46714214f36f2ec6f16e7af945 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp16_grad/arithmetic_self_grad.h" +#include "nnacl_c/fp16_grad/arithmetic_self_grad.h" namespace mindspore::kernel { class ArithmeticSelfGradFp16CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h index 685ede33c1aeb112eb159047e1c168ee1da85192..24c72273aec008839e36da0f796e30a7ad12496f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h @@ -19,7 +19,7 @@ #include #include "src/executor/kernel_exec.h" -#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" namespace mindspore::kernel { class BiasGradCPUKernelFp16 : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc index 25787bad42d28b2ff7f1b1b96fa816bcb4c73315..b8817f9ee73a42f2d4126b8807e7955047ac961f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc @@ -24,7 +24,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/batch_norm.h" +#include "nnacl_c/fp16_grad/batch_norm.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h index 7c9757bcd787df9f95c04b9d1a806d45a8a1279d..6b931821b73c12ca8681278b655046975382721a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h @@ -19,7 +19,7 @@ #include #include "src/executor/kernel_exec.h" -#include "nnacl/fp32_grad/batch_norm_grad.h" +#include "nnacl_c/fp32_grad/batch_norm_grad.h" namespace mindspore::kernel { diff --git 
a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc index 6eeaf07ce04303f480f25360d92e5a9e042345b6..29b3f479c5e5e6c280339482c10bb98df0eb27f1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc @@ -16,11 +16,11 @@ #include "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.h" #include "src/litert/kernel_registry.h" -#include "nnacl/pack.h" -#include "nnacl/fp16_grad/convolution_grad_filter.h" -#include "nnacl/fp16_grad/pack_fp16_ext.h" -#include "nnacl/fp16_grad/gemm_fp16.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/pack.h" +#include "nnacl_c/fp16_grad/convolution_grad_filter.h" +#include "nnacl_c/fp16_grad/pack_fp16_ext.h" +#include "nnacl_c/fp16_grad/gemm_fp16.h" +#include "nnacl_c/errorcode.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc index 3ac5983292f88012e84c33529f1a1289b2fbda26..80c65cb7240d4b2894d6bb04e11d0324e5779730 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc @@ -16,10 +16,10 @@ #include "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.h" #include "src/litert/kernel_registry.h" -#include "nnacl/pack.h" -#include "nnacl/fp16_grad/pack_fp16_ext.h" -#include "nnacl/fp16_grad/gemm_fp16.h" -#include "nnacl/fp16_grad/convolution_grad_input.h" +#include "nnacl_c/pack.h" +#include "nnacl_c/fp16_grad/pack_fp16_ext.h" +#include "nnacl_c/fp16_grad/gemm_fp16.h" +#include "nnacl_c/fp16_grad/convolution_grad_input.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc index e37547ed573e8503452d9c61eec9343d4d25a18f..86ad09d5a11e5414339d1da8abf08230489a8b41 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc @@ -16,11 +16,11 @@ #include #include "src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.h" -#include "nnacl/fp16_grad/dropout_grad.h" +#include "nnacl_c/fp16_grad/dropout_grad.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32_grad/dropout_parameter.h" +#include "nnacl_c/fp32_grad/dropout_parameter.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc index 6576ac15659951dbd67eb54ba286b53ca56e0224..a77b764a33c5d13be30d45fa07d9003e7793b5ff 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc @@ -19,8 +19,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/layernorm_grad.h" -#include "nnacl/fp32_grad/layernormgrad_parameter.h" +#include "nnacl_c/fp16_grad/layernorm_grad.h" +#include 
"nnacl_c/fp32_grad/layernormgrad_parameter.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc index 7a56be8e8b6f3c4fb1cfd2b1d0158ddb4a4b882c..10cf6d17bc92ecf2f02691a4504cfa29d41cb17b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp16/arithmetic_self_fp16.h" +#include "nnacl_c/fp16/arithmetic_self_fp16.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc index 1cd92cf08aac87f4b2f407e7810c56a5161bc771..2dc0d177f8989705baf4066d177e4fb4363d3cd3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16/pooling_fp16.h" -#include "nnacl/fp16_grad/pooling_grad.h" +#include "nnacl_c/fp16/pooling_fp16.h" +#include "nnacl_c/fp16_grad/pooling_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h index 678b1a3cf85f8dbaef5e096331b090d0b30a43a1..10bf78d79126d40a683572ec2aeca5a139929645 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/kernel/pooling.h" +#include "nnacl_c/kernel/pooling.h" namespace mindspore::kernel { using mindspore::schema::PadMode; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc index f1050657a4b866716e538803a4f79100855d4313..57064d04641486a0f412edb4c7fe38656073ef4b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc @@ -16,8 +16,8 @@ #include #include "src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.h" -#include "nnacl/fp16_grad/resize_grad.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/fp16_grad/resize_grad.h" +#include "nnacl_c/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc index b6db3387c6917684fb3cb922a24b3c918ddbd191..ca0fb484012eeb56d254ae58c497ef1f0d48ceb0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc @@ -19,7 +19,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/strided_slice_grad.h" +#include "nnacl_c/fp16_grad/strided_slice_grad.h" #include 
"src/common/ops/populate/strided_slice_populate.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h index f90a59156ba10bfddd82140bf407fca236a2b3b7..4c9015473ad97f835ffe1a1664158e3a7380d29c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_GRAD_STRIDED_SLICE_FP16_GRAD_H_ #include -#include "nnacl/fp16_grad/strided_slice_grad.h" +#include "nnacl_c/fp16_grad/strided_slice_grad.h" #include "src/executor/kernel_exec.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc index b7c1703f40a3d7164da686bca83bd955cc369925..f6e1eefcc505d1234787c17ade63981935ca94ab 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc @@ -19,7 +19,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/unsorted_segment_sum.h" +#include "nnacl_c/fp16_grad/unsorted_segment_sum.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc index a3b712e2d923faa6f78296ea81c06ea829bee077..1095386120ae51a5e1706965fb20a78f9d679e3a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc @@ -19,8 +19,8 @@ #include "src/litert/kernel_registry.h" #include "include/errorcode.h" #include "schema/model_generated.h" -#include "nnacl/fp32/adder_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/adder_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h index 366e930a765db89aa47ed937e389803cdaa93747..b2b1a7e80e3b10b308d90b748598ffde40ece4b2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h @@ -20,7 +20,7 @@ #ifdef ENABLE_NNACL_KERNEL_LIB #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/fp32/convolution_fp32.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc index 8f2d96b6e738ba7dc9d83fe770d723680dccfdf3..a183abc9e3e16d3c7540c422e990c39655adb996 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc @@ -19,8 +19,8 @@ #include #include "src/litert/kernel/cpu/fp32/matmul_fp32.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/fp32/splice_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/splice_fp32.h" #include "src/common/utils.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h 
b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h index 4cf296fee660e4505eec0d94a1181fc15b0efe0a..917c0d8f89d650a6aaf2b3f6c3cce6944568b4fb 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/affine_parameter.h" -#include "nnacl/splice_parameter.h" +#include "nnacl_c/affine_parameter.h" +#include "nnacl_c/splice_parameter.h" namespace mindspore::kernel { constexpr auto kAffineMinInputNum = 2; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h index 04f8066b11c18c57938113c4bab402f5595f175e..80585bf0c7405e70f846ef7abf8ec9909eb12328 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/all_gather_parameter.h" +#include "nnacl_c/all_gather_parameter.h" namespace mindspore::kernel { class AllGatherCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc index 9794faf021520ae196c62213f036e1bd16abadc5..bb290ccc93cc87bdc21bc26f39fd4eda1940dc29 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/fp32/arithmetic_fp32.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h index 90d158e40d781ac067e9cf2008472d151aa1931d..76360efd3c56067498485c923d7c8919b50f3975 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/broadcast_to.h" +#include "nnacl_c/base/broadcast_to.h" namespace mindspore::kernel { class BroadcastToCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h index e10cfacd980112a4bcb909bd5d9d559d5827cd3a..c90b2156fa8fc6b5f52df1148e70048ee1f54474 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h @@ -20,8 +20,8 @@ #include "include/errorcode.h" #include "src/litert/lite_kernel.h" #include "src/tensor.h" -#include "nnacl/op_base.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore::kernel { class CastCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h index 134827211eef7189419cb672192883bb7e785c5b..15efa5b86dab3a1c51e00c6f04d81d29b801589a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h @@ -21,13 +21,13 @@ #include #include "src/litert/lite_kernel.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include 
"nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" #include "src/litert/kernel/cpu/base/layout_transform.h" -#include "nnacl/base/conv1x1_base.h" -#include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/base/conv1x1_base.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/fp32/matmul_fp32.h" namespace mindspore::kernel { class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc index ea46a4152675c51bba33129c573c75a4e27b566c..9d26b96c0fc7b579bb32832281362115d841494e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc @@ -25,8 +25,8 @@ #include "src/litert/kernel/cpu/base/group_convolution_creator.h" #include "src/litert/kernel/cpu/fp32/group_convolution_fp32.h" #include "src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h" -#include "nnacl/base/conv_common_base.h" -#include "nnacl/fp32/conv_sw_arm64_fp32.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" #include "schema/model_generated.h" #include "include/errorcode.h" #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h index 6fb53eede63a3488e89cd8fb3297dc84b9c6fd98..3fcbf754a68653079cdece31e23939ef1c0a009e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h @@ -18,9 +18,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/op_base.h" using mindspore::lite::InnerContext; namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc index f4200f54f709c1bbebc64b73b72919b77e381edf..bacd4d528b6b33f217620b2d1102ca8d8c8c4287 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h index e186617990ba4332f16b31246bb03c857e7293fb..fa3ed605ff2057414af7e499913bbbc6356473c3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include 
"nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc index 510799e09814f74efdf0e3d3ecf7fab2bdfafaac..3ea5ac11183386c3490767685671ecd0303ed192 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc @@ -15,10 +15,10 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h" -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #include "include/errorcode.h" #include "src/litert/pack_weight_manager.h" -#include "nnacl/fp32/conv_depthwise_avx_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_avx_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h index 6179c437863401381b16d63e7bd731e43ae6aa38..8bdef907b8a892cc70991d9c249868b89aa16a71 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h index 9840e704e039cf11b67a9d31e376d411971560d5..ae08a9de36c45dd97a04fc8bec55377d42dbbe61 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h index e4d4bb46455af072984cd7342aa8d427937917d9..ecc94cd9ae21a00a40ad03e9a6bfcbbd9c683734 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h index 
e959fe457557e720f8df8fb4032c6f73946ee50d..d6b7ca98689719082c250feb0dce8e424d5454c3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwiseSWCPUKernelX86 : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc index b813ef80cb5f9556d4aea0b2c79da3a0ad5ca89f..6ce671278f3f71baea478a8f0f43e87e5e109224 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc @@ -18,11 +18,11 @@ #include "src/litert/kernel/cpu/fp32/convolution_fp32.h" #include "src/litert/pack_weight_manager.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/conv_common_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h index 0c2729911e03b4a76472271730884d32a6b56481..5ca88340f60926af1a2bab5a10148aa4f2358b63 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h @@ -20,7 +20,7 @@ #ifdef ENABLE_NNACL_KERNEL_LIB #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc index ed7d31fb6b862e03fdc24d3629e28f3924b277a7..de407f717ff4a6ce3cb7adf2dde12b91c5f18c78 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc index d892b94ba33a664142e996ce16e1ad913501072d..0b7ba386689159479b72fa14ad9b0c5c80e764f6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.h" -#include "nnacl/fp32/conv_im2col_avx512_fp32.h" +#include "nnacl_c/fp32/conv_im2col_avx512_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_NULL_PTR; diff --git 
a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc index 33a3368feac1df1170a4b7839d0264b855cada92..aaf7f597859d667086cbbd8c0aa14d34d6bdd38e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc index 9808547309a1ead2336b1080faf97af2db8dcd5b..9484973a67e3f04a360e8724012161c3089d3c58 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc @@ -17,11 +17,11 @@ #include "src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h" #include "src/litert/pack_weight_manager.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/conv_common_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h index b7d80d438c65429af0027c513fa51896d416fa96..70d00e87ddf79c2f6c292e27450b2d1d3cfcaaf8 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc index 409a1ee7cc20fcd0b4ff5ebd9cd9b43d25d00f60..908af5b0125c58fe852b06766554d3a602326d40 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc @@ -34,7 +34,7 @@ #if defined(ENABLE_ARM64) #include "src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.h" #endif -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" namespace mindspore::kernel { LiteKernel *CreateConvolutionIm2ColCPUKernel(OpParameter *parameter, const std::vector &inputs, diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h index c97e62b6c09e247f3fad268edfcea21a70c5549b..e1b22015a579cabc6afe39c05097497d77b44ea6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include 
"src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc index 32628a88e99a2bc71129e1dbaac207e2b4b90a9d..111388796b25fe64750b90075577b78fdb1b9256 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.h" -#include "nnacl/fp32/conv_sw_arm64_fp32.h" +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" namespace mindspore::kernel { void ConvolutionSWARM64CPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc index e8f3b62d2e7c525d6ddebb58cce7bef7e1459ca5..54dcd0a2fd4c53cba8fb022c9f9e29756dfe11b1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc @@ -15,8 +15,8 @@ */ #ifdef ENABLE_AVX #include "src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.h" -#include "nnacl/fp32/conv_common_fp32.h" -#include "nnacl/fp32/conv_1x1_x86_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_1x1_x86_fp32.h" namespace mindspore::kernel { void ConvolutionSWAVXCPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc index e46bdb01c50242d6119ecfad3166fceec0d57658..4026cf8535922077fc51004966d50dc1ec12ec4b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc @@ -15,8 +15,8 @@ */ #if defined(ENABLE_AVX) || defined(ENABLE_ARM64) #include "src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h index e00ce68d1cac3eb4a228060b5613d2944173024a..6f8f40805e20d263ce99f879725a36ff443cd63d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h @@ -18,7 +18,7 @@ #if defined(ENABLE_AVX) || defined(ENABLE_ARM64) #include #include "src/executor/kernel_exec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h index 2349dc9b75f14d7669190c852a11e22cc555ad93..071d21d3c1cd0f0301b5165e30480d4fc86dc88b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h @@ -19,9 +19,9 @@ #include #include 
"include/errorcode.h" -#include "nnacl/intrinsics/ms_simd_cpu_info.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc index f8ce6fe5d2f1ea30dd192d2c284f9683f9c5b894..ee8ea71c0b129bfd45daa888b1f0321a85f1a306 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.h" -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc index 80dd3d9e7704454309fc80cf794fc430fba608c3..1837425a7d6c11814fccc94864efa3c65720b0cf 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.h" -#include "nnacl/fp32/conv_winograd_fp32.h" -#include "nnacl/pack.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc index df4bc44185a8b5c2f6428ec4a5f43a786ab97ee2..e77872c127f8d6100811e72ca3b9df23f0f6e741 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h" -#include "nnacl/fp32/conv_winograd_fp32.h" -#include "nnacl/pack.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h index 4f6a45fd3c1c9f9eae43755b0b37b0f211017a41..0968c500b1e678545ee54a87ffb1f47cc5799004 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h @@ -19,9 +19,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/winograd_transform.h" -#include "nnacl/base/minimal_filtering_generator.h" -#include "nnacl/fp32/conv_winograd_fp32.h" +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" #include "src/litert/kernel/cpu/base/convolution_base.h" #define CONV_INPUT_UNIT_SIZE 8 diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc index 
bbd7e6d04273fcb6a56a687211bea089f4a2a438..c6a2e2e133d34800a5d14a9ddc58f0267ff27097 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc @@ -31,7 +31,7 @@ #if defined(ENABLE_ARM64) #include "src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.h" #endif -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" namespace mindspore::kernel { LiteKernel *CreateConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector &inputs, diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h index e38cd91120a320a3cfc4062e86925bcf630b4f26..67c69734f1cf628acec6115e8a36c625bfedbfab 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc index 9e5acfe6e998f4128c4fb83e0e27532ab0bf7ab9..969ecc6cb3f2b3ee0155a587c04936418cac35f7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/fp32/cumsum_fp32.h" -#include "nnacl/fp32/cumsum_fp32.h" +#include "nnacl_c/fp32/cumsum_fp32.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h index 16edda65cd9bd2683d87e100f7e12a7ea4816b37..d26af0b5d8ac673e22c66b22b429f99d3b23b38b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h @@ -18,7 +18,7 @@ #include #include "include/errorcode.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/cumsum_parameter.h" #include "src/executor/kernel_exec.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc index 37e7d3ac7eeceae0fb617fa21ff5a7c0eee48867..1b9d7e6d1d71b412c6735569d26b05cb641212c2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc @@ -20,9 +20,9 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/litert/pack_weight_manager.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/custom_gru_parameter.h" -#include "nnacl/fp32/custom_gru_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/custom_gru_parameter.h" +#include "nnacl_c/fp32/custom_gru_fp32.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h index eae5f0cf54745ecacdc9d6dd168c58c4698c3203..de8298cb13272607bae4822618ab23b84a13e08f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h +++ 
b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class DeconvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h index c76f5ca76ab8e210688f516a9313ca1cfc6b00ce..af0579a77ac8317968371a215da8ff0cf4ba2ce9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h @@ -25,8 +25,8 @@ #include "include/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/deconv_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/deconv_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" namespace mindspore::kernel { #define DECONV_WINOGRAD_MAX 2000 diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h index 9f4646aa5d14656f7c828f9ac523e024bab958ac..1aff6a336c596b5568aaf8701414cfaf9ea61101 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h @@ -24,8 +24,8 @@ #include "src/litert/kernel_registry.h" #include "include/errorcode.h" #include "schema/model_generated.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/deconv_winograd_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/deconv_winograd_fp32.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc index a9907583053e9b040baa25af4a9401d39411e8dd..b29aa3fdc99cda478c460f0335fcb03aabbc9d78 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h index 4261c6b1b059d85770bf9e35efb5cec95f3bb38c..d641c8c030322604593c910880f9d9202611b8df 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/detection_post_process_base.h" -#include "nnacl/fp32/detection_post_process_fp32.h" +#include "nnacl_c/fp32/detection_post_process_fp32.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h index 58be4aad9c732691ef98cd710f055f7d5dc51fde..6813374c9818f2414c63e8c81763ea1bc6269065 
100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/embedding_lookup_fp32.h" +#include "nnacl_c/fp32/embedding_lookup_fp32.h" namespace mindspore::kernel { class EmbeddingLookupCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc index 4f080ed79555be6230c93d8e05bd5725754ee26e..c97df7cc9ae55da2d1d63226f0b30f3bad381dea 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc @@ -18,8 +18,8 @@ #include #include "src/litert/kernel_registry.h" #include "src/litert/kernel/cpu/base/split_base.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h index 7bb4cc4ff4049bdb324410b6051aa1b1a752c8ef..903e8fe7665b7d8bbb0b6ddad9da9a80a8ac312d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h @@ -20,9 +20,9 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" -#include "nnacl/split_parameter.h" -#include "nnacl/glu_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/glu_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h index 8c5c1fac682bc831b932ec316878b52392c20a4e..de04f9025095b65b972535a9c866e4334b7d06b1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/group_convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc index c6794d0f09b29e721b7beb7a949e2b8a2e652eb4..7c7a872082086915d8d3512e3ab3c8f16085eff0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc @@ -18,8 +18,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/gru_fp32.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/gru_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h index 6db5f0068bcdbeaec546a4df20b4680f27228f27..0736f4ffdb8778cd416fefa8e9976fadcfb27ee6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRU_FP32_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/gru_parameter.h" +#include 
"nnacl_c/gru_parameter.h" namespace mindspore::kernel { class GruCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc index 8359695b65344b0d7772407ed87992a80e7077cc..83765b80fa2f28ae31f24401167a03c9258c06f5 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc @@ -17,8 +17,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/instance_norm_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/instance_norm_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h index 56926cec334d0320fd41fb4b1b33e4fff2ecfefa..1ce9cec68936d262181561863d99a7f8f150a716 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_INSTANCE_NORM_FP32_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/instance_norm_parameter.h" +#include "nnacl_c/instance_norm_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc index 49b1b48da56e94b2bf9c9d3a6ce103eba8be434c..f15019d47e6acd17d3354d20023757b9e45ffcf3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/invert_permutation_fp32.h" #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" -#include "nnacl/fp32/invert_permutation_fp32.h" -#include "mindspore/ops/kernel/cpu/nnacl/errorcode.h" +#include "nnacl_c/fp32/invert_permutation_fp32.h" +#include "nnacl_c/errorcode.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc index 06a571072509e956d6dbf1814a1d22f9ff70c229..67b2771a514c49cbba7f35e8448e150bcd4d661c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc @@ -18,7 +18,7 @@ #include #include "src/litert/kernel/cpu/fp32/l2_norm_fp32.h" #include "include/errorcode.h" -#include "nnacl/fp32/l2_norm_fp32.h" +#include "nnacl_c/fp32/l2_norm_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h index 032a63c45527cc13911f2fee784bd1ee3da77aea..eb09b3b6d191a78777eff58da6978f8ba642e515 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/l2_norm_parameter.h" +#include "nnacl_c/l2_norm_parameter.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git 
a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc index bd0f0e7d575063a201254e024ba47b7b77a926e6..d5975a8d915d3ff670c6568c8f8c8725a778c6f9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" #include #include "include/errorcode.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h index 38800c0711241761969177d36a8724836e6cdd4c..2f96c661751084bbb82b29206a2b0e5a3ebdbdec 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" namespace mindspore::kernel { class LstmFp32BaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc index 476d5940dce0ec89a23829ba9d8d0ecfac96c2e8..97ccf931d3006bd71fcdde80f1b7f587936ff7b2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc index 62f9f2b7bf794847c4f5af3821eb09e47aebfa86..317ea2cfbe205df10def94362e327fe98de0d2a1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc index 29e8dba3c1bd1b293fc6e7080ac4b1f943899925..42cea2732ce06d515fb9768b02182cf09b0442fe 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc @@ -17,9 +17,9 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32.h" #include #include "include/errorcode.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/litert/kernel_registry.h" -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #if defined(ENABLE_AVX512) #include "src/litert/kernel/cpu/fp32/matmul_fp32_avx512.h" #endif diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h index 262c7da6e5433d7a07a274486bc4b919633ae5ac..1f0f3403e2d3359014a6a9752496d838896ed739 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h @@ -18,7 +18,7 @@ #define 
MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_MATMUL_FP32_H_ #include -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc index 86e28d2ee7767d1b626a73b0929da4160f0d521f..55cd42fb4f8db3f48f94c96d4211bd1e334539eb 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_arm32.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { void MatmulFp32ARM32CPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc index a0aaddff73e9a5e2063ce2da5858069d37a4e0bf..902c01fc25c5adb3a0bc9933c78fecda6109de26 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc @@ -18,9 +18,9 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_arm64.h" #include #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/pack_fp32_opt.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32_opt.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc index b9bb87815cd0753f6971e2e88fc57eab926ec80e..401bff220f5b80250c5e5e165509446b45ff1830 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_avx.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { void MatmulFp32AVXCPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc index e3cbbbbcba22c9db99372e26ed2b85932432a2cd..bd96027f135ba2f76ddcdc6134519119a1501ece 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc @@ -17,10 +17,10 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_avx512.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_avx512_fp32.h" -#include "nnacl/fp32/matmul_avx512_mask_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc index 
38b72c3932eb4bb27da70a47a4766fc6db01a83b..cd202df3c2508a0631fe445277e82e5231c00c2b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc @@ -16,9 +16,9 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" #include -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/pack_fp32_opt.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32_opt.h" #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h index cf17528c74eec9dedde5ed11aa66725e233082d8..f1ac4f117245b38be851d8c2696bae7749f7c831 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/pack_weight_manager.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "include/errorcode.h" #include "src/common/common.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc index 7790fb7b53e9d3c525333a540d3e81b7d10d6a70..996e968dc743870008e060d64f48cd6787a74e7f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_sse.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { void MatmulFp32SSECPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc index 00d63f21e6dcd8c4c703ad0abaa705e1c5e4f22d..5c31033fe9fc4c6538252a2a59fd71f9b81e098e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc @@ -18,7 +18,7 @@ #include #include #include -#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl_c/non_max_suppression_parameter.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h index 7e6b011bc607766d138487e21b892c2b61b4217d..0cce8151fd7fe48f0c67e9c346acec4c1667c1b1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h @@ -21,7 +21,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl_c/non_max_suppression_parameter.h" using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc index 59d6a5d76b4c6887f4ab7fd4c719f70752fa484a..5a937d59bbd4a9378eebdbf47246a1699ae00fc2 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.h" -#include "nnacl/fp32/online_fusion/cast_gather_reduce_fp32.h" +#include "nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h" #include #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc index 9e890e8e537484fa7a5ea70fdc4168bdea0bddbf..76a7b698199c9a91fea8febea770b5add9cc982a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.h" -#include "nnacl/fp32/online_fusion/reduce_concat_fp32.h" +#include "nnacl_c/fp32/online_fusion/reduce_concat_fp32.h" #include #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc index b77dea64f343a271b7df3aea71d593a97b7ae707..d6f00fb94774fcb07c8f0d5e43702af8ab297c0b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h" -#include "nnacl/fp32/online_fusion/split_reduce_concat_fp32.h" +#include "nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h" #include #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h index bcaaaabba8c35eb7a51cba57abd408b2c66b7cbc..376cd7c4c56abb71d9e7fe439843fa3b4d4d65bd 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::kernel { class SplitReduceConcatFusionCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h index 909d751d10401245618955f487166462b9e7db74..6dd140fcdc4b1893d09ff147bb40104bbd813c76 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/reduce_scatter_parameter.h" +#include "nnacl_c/reduce_scatter_parameter.h" namespace mindspore::kernel { class ReduceScatterCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc index 8f83579f4de83ed81aad88e0d5a53d5c933ed78b..1bb57069b830e41b799bb8fa80881ad00ac22211 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h index 1c73c94b13c732d33a815d8bdb37bca5918f4669..1ad095d44988ff6249739a88b5084669df85704b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/attention_fp32.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/fp32/attention_fp32.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { // inputs: 0:Q 1:K 2:V 3:P 4:WQ 5:WK 6:WV 7:WP 8:PU 9:PV 10:WO 11:BQ 12:BK 13:BV 14:BO 15:output diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h index f7b36b5c4fa003719831e59ab938a1b89f97cf4e..a38309d6436de546f67efc515deeb8ec80d6eb83 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h @@ -19,7 +19,7 @@ #include #include #include "include/errorcode.h" -#include "nnacl/fp32/resize_fp32.h" +#include "nnacl_c/fp32/resize_fp32.h" #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/resize_base.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h index a7b28818ba790615015eeaaf6b35f6cf9ea84de7..5960b3703bae00181f192b0f9d4592e65589bfa2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/reverse_sequence_fp32.h" +#include "nnacl_c/fp32/reverse_sequence_fp32.h" namespace mindspore::kernel { class ReverseSequenceCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc index c341271246d1cfeaac0afade8705bd330dfb4dd6..cbc26b7eafdf8eb3a6fbae9a69a2fa3877d1280d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/litert/kernel/cpu/fp32/roi_pooling_fp32.h" -#include "nnacl/fp32/roi_pooling_fp32.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h index f0a029052bf33949d9783506ac4a0482dc5ca81f..de6af7b0f53125827f88de863dbef9aac6dc208d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/roi_pooling_fp32.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" namespace mindspore::kernel { class ROIPoolingCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h index 1ea041d2473b2a0a80a2cbd698feca9e7bc23f3f..73c99f6f57a5e0ada152995ba997544d87fc1658 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h @@ -18,8 +18,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/space_to_batch_fp32.h" -#include "nnacl/common_func.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" +#include "nnacl_c/common_func.h" namespace mindspore::kernel { class SpaceToBatchCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc index cea0b28504127c06b873a8074ce3adefa13b5149..eb574a07f935d1250104941c1d00692f8f62a6a9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc @@ -14,10 +14,10 @@ * limitations under the License. 
*/ #include "src/litert/kernel/cpu/fp32/space_to_depth_fp32.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/base/space_to_depth_base.h" +#include "nnacl_c/base/space_to_depth_base.h" #include "include/errorcode.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h index 6f92f7b920cd1a6535eb13cf4183976f6cee2a24..2e46b09129dc6441e49085a374e14805de026b0e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" namespace mindspore::kernel { class SpaceToDepthCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc index 84c9ef33c780e72fa28514209dbd203358c2873e..2c96dac32f6b084ceba7ec31fb3e90dd342b8e3e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc index 165ae71bbbbe02dadaf8ee56eea865fcb84234c3..98e637f1d4fe48d730bbc9a9e9019eb66509e6c7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc index 9d5de564599717d7cce6b2331e5af04cf247d86c..59491b5b3827eecbbac85c41f08a1f0ea02c75bb 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc index e3411c7bacbc9a1d98e92e56972524c8239d3b22..23c2d68d587aba84d3fa6614862e0d728b58e56e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc @@ -17,9 +17,9 @@ #include #include #include "include/errorcode.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" #ifdef ENABLE_FP16 
-#include "nnacl/fp16/sparse_to_dense_fp16.h" +#include "nnacl_c/fp16/sparse_to_dense_fp16.h" #endif #include "schema/ops_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h index 6ab288d4e410b5ceaa9d99ffa657776d42404b64..c4d0d2ce3571b7de49ed3a9c7633a5cc52af87b2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" #include "src/litert/kernel/cpu/base/layout_transform.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h index 02c553221f427b2673faf45d5602dd46981877e1..27ebdfa18ddc9258849e1c611710489d3d4cbbe2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h @@ -18,9 +18,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" #ifdef ENABLE_FP16 -#include "nnacl/fp16/topk_fp16.h" +#include "nnacl_c/fp16/topk_fp16.h" #endif namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc index 238d85437f86ffb0ccadec06a9b230fc1f92348e..18233eea6298e6f82b53d63eefabf7598467af40 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc @@ -16,7 +16,7 @@ */ #include "src/litert/kernel/cpu/fp32/transpose_server_fp32.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h index 26a35c7e3cce89a782434d7057c8449820ab70b8..3793c3fcdd02373190c6a4d4a3afb6388e5a492e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h @@ -19,7 +19,7 @@ #ifdef BFC_MEMORY #include #include "src/litert/kernel/cpu/base/transpose_base.h" -#include "nnacl/fp32/transpose_server_fp32.h" +#include "nnacl_c/fp32/transpose_server_fp32.h" namespace mindspore::kernel { class TransposeServerCPUKernel : public TransposeBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h index f6cada1d44d1b197fac34bf8c9efa327b10fbcd5..790d1e29cf3723322f6fc96275bc998b6d339417 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" namespace mindspore::kernel { class UniformRealCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h index 21212a747671ca71e07ca3a42625e8daec3eea1c..2e11ac413d5bce1a6a073732a2016c7969f5d2ee 
100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/unstack_base.h" +#include "nnacl_c/base/unstack_base.h" namespace mindspore::kernel { class UnstackCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc index e442c2164eac24c9145c9d359bb9616f99c0a661..15f6ac92a88ed9425314ebff41430506b6c21b2b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32_grad/activation_grad.h" -#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl_c/fp32_grad/activation_grad_fp32.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h index a5be9df8700013c28b8732962eb76d52b47d65c3..1e0e79cc94456ae4c5a8d5850051ae30c8203394 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" namespace mindspore::kernel { class ActivationGradCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc index de17bc03f652cb8393335d2127e4f0bad08d7d02..1a783e936067682693f962f8d827673c0e25c5ae 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc @@ -20,8 +20,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/adam_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/adam_fp32.h" +#include "nnacl_c/op_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h index dab71ae52d0f9c36125a17a4b37e0844013c1626..dfcc88355e1c3d253013f0da431d3a47c1598ca2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h @@ -19,7 +19,7 @@ #include #include "src/train/optimizer_kernel.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" namespace mindspore::kernel { constexpr int kAdamLrIndex = 5; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc index 1b3d578e27544c072483a29b08ae7dcb48de65c9..5c30fb8db014001ba96d1b6ce778877e93dc888f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/adam_fp32.h" +#include "nnacl_c/fp32/adam_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git 
a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h index 2fd6ef7824ff1d8f43068a4fbd600fd1dce4cd53..d6f11298c97ce21acd52c2a26f2e02a8e9b3dd18 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h @@ -19,7 +19,7 @@ #include #include "src/train/optimizer_kernel.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" namespace mindspore::kernel { constexpr int kApplyMomentumLrIndex = 2; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc index b74231e2c6ab9c1e3f9f34ca8bed401cee085c74..309d584ede338442f4749889df3933c88485a920 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc @@ -17,10 +17,10 @@ #include "src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32_grad/reduce_grad.h" -#include "nnacl/fp32_grad/arithmetic_grad.h" +#include "nnacl_c/fp32_grad/reduce_grad.h" +#include "nnacl_c/fp32_grad/arithmetic_grad.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h index 92bcb2d80b29d5fa8a6e6f6f73d649c9ff893269..8402685d6ae676c093a4c52ebed339d97f52b1e7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "schema/model_generated.h" using mindspore::schema::PrimitiveType_AddGrad; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc index 725bddd286f5c0b3aac55a48e400847b1553a8b6..24044b8f1989f3b0d1405d253243af98d93b3835 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc @@ -18,9 +18,9 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/arithmetic_fp32.h" -#include "nnacl/fp32_grad/arithmetic_grad.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32_grad/arithmetic_grad.h" +#include "nnacl_c/op_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h index 6d42ac377cc7e0cc77d5bf1a24ffd1c81a01e96d..9b1391154e217e97774e453b2faf63c860d1a7ca 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" namespace mindspore::kernel { class AssignCPUKernel : public LiteKernel { diff --git 
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h
index 056c88f69ee472088dd502b44ba6c060b48b8179..a759e0865f325e504653e350948b214b5f998332 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
 namespace mindspore::kernel {
 class BiasGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc
index 814977fbade96cd7c563ad28b8f42fdc1e6d1025..d0376e1493d05d8c4d38710e04ef81c509a1fe74 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc
@@ -17,7 +17,7 @@
 #include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/binary_cross_entropy.h"
+#include "nnacl_c/fp32_grad/binary_cross_entropy.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
index cd50a86b2d19b0272bdfa9ad071124c25db7eb72..36b3676c21ae0a5f33479ef698f9eed606a5943a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc
@@ -17,7 +17,7 @@
 #include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
+#include "nnacl_c/fp32_grad/binary_cross_entropy_grad.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc
index 4855042690f03f2aed59b2bbd25023fccf36887c..ea84f8db062532720dad2c71d88f40abd92853e0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc
@@ -23,7 +23,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32_grad/batch_norm_grad.h"
+#include "nnacl_c/fp32_grad/batch_norm_grad.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH::kCPU;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc
index d9eaff5e52d20bbc1d59d5667c2121857913822c..0639cec2901d2ee02135b72dede8a8e55a76daf0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc
@@ -15,10 +15,10 @@
 */
 #include "src/litert/kernel/cpu/fp32_grad/convolution.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
 #include "include/errorcode.h"
-#include "nnacl/pack.h"
+#include "nnacl_c/pack.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc
index 7bcb222e583e46c3b75b16baf5d338f11df0a9a1..35fc941ac1c1005c65b8018c23ecef1430ba5da9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc
@@ -16,11 +16,11 @@
 #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32_grad/convolution_grad_filter.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32_grad/convolution_grad_filter.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
+#include "nnacl_c/errorcode.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc
index 68e5968ab7d696ca2b44478a807fba86b7add14e..515ce2ec573e2e7374cda841bc732e9bd9ab3ce4 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc
@@ -16,10 +16,10 @@
 #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
-#include "nnacl/fp32_grad/convolution_grad_input.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
+#include "nnacl_c/fp32_grad/convolution_grad_input.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc
index 5b2dae82597058df8b1faa9c21631bab662d8bce..94ced4c2a71603ae2c72a5dc8da43e1396b5672c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc
@@ -16,9 +16,9 @@
 #include "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc
index d5df2396d6214243be36a0602c71a7102582a9b0..aef8c820cd41dcbf82e7b8c538a898ade7dcc557 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc
@@ -19,7 +19,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/dropout_parameter.h"
+#include "nnacl_c/fp32_grad/dropout_parameter.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc
index c22b6b9cfc8a81a89c52c6f4ea0e920e92ecb38a..f057ba4f00ecd2a072a6eb11fd7a39d8768f9f6f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc
@@ -16,11 +16,11 @@
 #include
 #include "src/litert/kernel/cpu/fp32_grad/dropout_grad.h"
-#include "nnacl/fp32_grad/dropout_grad.h"
+#include "nnacl_c/fp32_grad/dropout_grad.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/dropout_parameter.h"
+#include "nnacl_c/fp32_grad/dropout_parameter.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc
index ea9d9b0f92c6e6f9f9cd0120c29aa009b0637ac9..4edbed4a7ee27a9d8d1cc48681129824981eaeec 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc
@@ -19,9 +19,9 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32_grad/layernorm_grad.h"
-#include "nnacl/fp32_grad/layernormgrad_parameter.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32_grad/layernorm_grad.h"
+#include "nnacl_c/fp32_grad/layernormgrad_parameter.h"
+#include "nnacl_c/errorcode.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc
index b5f0a3ee19b5b9478040a44477662dc79cc0ba24..1d783c5db3c9921633c589d33721d30860a1f6fa 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc
@@ -20,7 +20,7 @@
 #include "utils/ms_utils.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h
index 429c875d3bc792ace87ead88e0b9ebc23c2e5809..6f1f62b0e91d95e449193df4478a815f6de43a33 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc
index a4979a9d6e71cfacbb730b84003b1c973bff748a..a8edc072c77ece268d4d272bb624507b12f68f30 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc
@@ -20,7 +20,7 @@
 #include "utils/ms_utils.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h
index 6eabbc9ce6944c30e540fc333382b5b0723b6834..bca70e2d8e11532515028c65cbf879806176ae8d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc
index 51a9d2b631d96373b152c19507106d746f0e5b54..4b064da1e178ac649ca841a81b6c752470d39de8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc
@@ -19,7 +19,7 @@
 #include "utils/ms_utils.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h
index 844bc011c5f70a1cf29c01c1f77c1f31d3094008..5db3f80e48bb3fbabe084af7e5059aef2c6d194a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc
index ab058d9ba9e3e4396a20f80b9f7ded49823dcb4d..916d7e400e1e142eec99b355d8e80cf10c3b6fdb 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc
@@ -18,7 +18,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32/arithmetic_self_fp32.h"
+#include "nnacl_c/fp32/arithmetic_self_fp32.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc
index 8dd4938a53f5f3a3e0823bc1a9937a51f596b025..26751714899fba5bf3d87c5a7a7f2a9ff1e67c59 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc
@@ -21,7 +21,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/nllloss_grad_fp32.h"
+#include "nnacl_c/fp32_grad/nllloss_grad_fp32.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h
index df4433768df53026b2c6e0cf34a4a22b0ba8996e..5b58572807499731f63337201fefd5e8dff52211 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/nllloss_parameter.h"
+#include "nnacl_c/nllloss_parameter.h"
 namespace mindspore::kernel {
 class NLLLossGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc
index 55752ce3b845f8852dbac21e71d49ad90753a298..f79c2401409442d9452e6eb2ce6283228af7fdf0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc
@@ -17,8 +17,8 @@
 #include "src/litert/kernel/cpu/fp32_grad/pooling_grad.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/pooling_fp32.h"
-#include "nnacl/fp32_grad/pooling_grad.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/fp32_grad/pooling_grad.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h
index 8a9e53d8223eb079b8e88c6ea9d681fff437a6a7..0e68017209737668206565d2004cf249fea82d33 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/kernel/pooling.h"
+#include "nnacl_c/kernel/pooling.h"
 namespace mindspore::kernel {
 using mindspore::schema::PadMode;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc
index 77e5af4c9bcb03e0e2afddb2a867e64d274ac325..2247d2bc4ee2af6c97bd83f136ae40e9614edba9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc
@@ -15,11 +15,11 @@
 */
 #include "src/litert/kernel/cpu/fp32_grad/power_grad.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h
index ec7257bcee682dabc5eb97088e928b1786a0852b..5c6d2e3cb0d0c5ee6768484bcee5117d3df6b514 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/pow_parameter.h"
-#include "nnacl/fp32/power_fp32.h"
+#include "nnacl_c/pow_parameter.h"
+#include "nnacl_c/fp32/power_fp32.h"
 namespace mindspore::kernel {
 class PowerGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc
index 15b035d571520f0e68fbd6683bba1e369bdda0f5..43ec046901daa5bd3ccbf6dafef0c5c9dfa67a4b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc
@@ -16,12 +16,12 @@
 #include
 #include "src/litert/kernel/cpu/fp32_grad/resize_grad.h"
-#include "nnacl/fp32_grad/resize_grad.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32_grad/resize_grad.h"
+#include "nnacl_c/errorcode.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h
index d43de4a4d11c4aff802f6c012db1701d8417269e..2b8758b0638908b05046f4d376d84f566c58bbb3 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "src/train/optimizer_kernel.h"
-#include "nnacl/fp32_grad/optimizer.h"
+#include "nnacl_c/fp32_grad/optimizer.h"
 namespace mindspore::kernel {
 constexpr int kSgdLrIndex = 2;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h
index f4560c539969cb7ab61501024c5c84ef910c4f75..ea93ea06e6594e53763951f5e4b5031407ef45f4 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/smooth_l1_loss.h"
+#include "nnacl_c/fp32_grad/smooth_l1_loss.h"
 namespace mindspore::kernel {
 class SmoothL1LossCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h
index d33b758a97b9a37d12c15a9e5274e70203eee960..2af38da5de6dbb7ef580634e52d81c91aa42e818 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/smooth_l1_loss.h"
+#include "nnacl_c/fp32_grad/smooth_l1_loss.h"
 namespace mindspore::kernel {
 class SmoothL1LossGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc
index d01602d50231d76976bfd03a25a414d58f26bacf..67551d30c76ed3f20070d90189b573723dbba70c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc
@@ -16,9 +16,9 @@
 #include
 #include "src/litert/kernel_registry.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/fp32/softmax_fp32.h"
-#include "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h"
 #include "src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h
index 51fcca830a7009bacd6a09a7e9a179e244938197..416a8152e5189397fcf36b69317938d5c4a66da8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h
@@ -19,9 +19,9 @@
 #include
 #include "src/train/loss_kernel.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/softmax_parameter.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc
index 55ec5dc18dc2b9d7fa3161a6f977e909932106cb..b2d6f003634c62a6a5186825ecec5b7387d746f9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc
@@ -17,7 +17,7 @@
 #include "src/litert/kernel/cpu/fp32_grad/softmax_grad.h"
 #include
 #include
-#include "nnacl/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h
index ccde97fb3fea68482ab732e904d6a144f94ad684..f87f857c54271146e5e98868be9ff63719291330 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 namespace mindspore::kernel {
 class SoftmaxGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
index 7fbb4cc47cb51797cd63b1a838751af70ba4c286..200c0e14f859f427580d4611a5b971a0499cdd62 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
@@ -16,9 +16,9 @@
 #include
 #include "src/litert/kernel_registry.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/fp32/softmax_fp32.h"
-#include "nnacl/fp32_grad/softmax_grad_utils.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32_grad/softmax_grad_utils.h"
 #include "src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h
index f24c9a4ee9d0b9ce410fc05d28eb4c5c63c657e8..12730a04d3c0c9c71e1ee0bb2c09d1c0d51d8a5c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h
@@ -19,9 +19,9 @@
 #include
 #include "src/train/loss_kernel.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/softmax_parameter.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc
index 91bbd39aab05eea54d342564126afb0c7a9531eb..982be44bf68f2ba894286e2678b8fdf77e22ab48 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc
@@ -21,7 +21,7 @@
 #include
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
"nnacl/fp32_grad/strided_slice_grad.h" +#include "nnacl_c/fp32_grad/strided_slice_grad.h" #include "src/common/ops/populate/strided_slice_populate.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h index f34dd20d4ca2b2a53017470a29f8b530e25a14da..10fc1d50b0eaa54d0a1b5cc1d1635300f4b3beac 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_STRIDED_SLICE_GRAD_H_ #include -#include "nnacl/fp32_grad/strided_slice_grad.h" +#include "nnacl_c/fp32_grad/strided_slice_grad.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc index 7c97f07cd6772676d5d02c2cedb23251024616dd..681968fcb378f4097ebf68874cacee718dbeb6c6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc @@ -19,7 +19,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/base/unsorted_segment_sum_base.h" +#include "nnacl_c/base/unsorted_segment_sum_base.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc index 47373ccd2e21add4ca68764d42a64b1d5e66205f..220901e8012fe8476a4d44c8b8f6eb44d3622fd3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc @@ -23,9 +23,9 @@ #ifdef ENABLE_ARM64 #include #endif -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32_sparse/matmul_sparse_x1_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h index 9475cad99ee4439a7934b5f9642c6b01db330bbd..1e6c94b63ff359cb61dbf5d877fc1b4e0df2595e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h @@ -18,9 +18,9 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_SPARSE_MATMUL_SPARSE_FP32_H_ #include -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/transpose_fp32.h" +#include "nnacl_c/fp32/transpose_fp32.h" namespace mindspore::kernel { struct SparsityWeight { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc index 5e6c9ccc19ad828c0c4895ab285aaa2366887a5e..3096102078c230625b3f54d652bb0b73dafeafa7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/int8/add_int8.h" -#include "nnacl/int8/quantize.h" +#include 
"nnacl_c/int8/quantize.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" #include "src/common/file_utils.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h index c4c8e80b7f3a46327d23da02ac54aab2128e00a7..586184fcee5a83d4162389e085399deb65dc0c40 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h @@ -20,8 +20,8 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/add_int8.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore::kernel { class QuantizedAddCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h index a94a5f006c797da7fa23a8f91529030ecec591c4..4eb95ab0272bdb51f3b4f6fcf501751a27960ab9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h @@ -17,12 +17,12 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_ARGMINMAX_INT8_H_ #include -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/arg_min_max_int8.h" -#include "nnacl/common_func.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/arg_min_max_int8.h" +#include "nnacl_c/common_func.h" #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/kernel/arg_min_max.h" +#include "nnacl_c/kernel/arg_min_max.h" namespace mindspore::kernel { class ArgMinMaxInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc index 6312577c25b5b276733a31c4f67e0d47448bc97d..ed066209f9cb8bd59361a4098df2f34917a0cfae 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/cpu/int8/arithmetic_int8.h" #include "src/litert/kernel/cpu/int8/add_int8.h" #include "src/litert/kernel/cpu/int8/mul_int8.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h index f25a928ebaf7b66e62692ebcb96e6a855820e299..3c60c6fd2959ad7dc8d2d807fc2f1708ade75946 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "schema/model_generated.h" -#include "nnacl/int8/arithmetic_int8.h" +#include "nnacl_c/int8/arithmetic_int8.h" namespace mindspore::kernel { class ArithmeticInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h index 4477391c5ae3e66a41f7f8f1e2ee39d4088c3d7a..8930fe32e254e50e6273801164dd77c366a8e87c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/arithmetic_self_parameter.h" -#include 
"nnacl/int8/arithmetic_self_int8.h" +#include "nnacl_c/arithmetic_self_parameter.h" +#include "nnacl_c/int8/arithmetic_self_int8.h" #include "schema/model_generated.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h index 64a8ece1868dd11e7865fc33623fc772503d4313..9cd263d80b556c9ab68cfe9781e190715421c46e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h @@ -18,9 +18,9 @@ #include #include "include/errorcode.h" -#include "nnacl/batch_to_space_parameter.h" -#include "nnacl/base/batch_to_space_base.h" -#include "nnacl/int8/batch_to_space_int8.h" +#include "nnacl_c/batch_to_space_parameter.h" +#include "nnacl_c/base/batch_to_space_base.h" +#include "nnacl_c/int8/batch_to_space_int8.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc index 39f9ee3d3226cc1090004c9b635c9f4bbc516207..10ed8e28004fca74492b2baf9a44e8b7aa4bdb6b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h index 155342ec125245614e0b31c5b7b8cfb30930ee91..3312cd0f4f04125d5a63362bcdede03c81c5337a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/batchnorm_int8.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/int8/batchnorm_int8.h" +#include "nnacl_c/batchnorm_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h index 5b0e5a03ca6c27bfe7012aceff1eff92eb6e9c08..89d691d9db991ca9c52f46d4af0a58688ba88406 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h @@ -19,10 +19,10 @@ #include #include -#include "nnacl/int8/concat_int8.h" +#include "nnacl_c/int8/concat_int8.h" #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/kernel/concat.h" +#include "nnacl_c/kernel/concat.h" namespace mindspore::kernel { class ConcatInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h index 0d402dc1f896b69c38dc229ecffc9d82ea055782..0d009cbf701cffa498a7213c8e735fff8e45d892 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h @@ -22,10 +22,10 @@ #include "include/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/int8/conv1x1_int8.h" -#include "nnacl/base/conv1x1_base.h" -#include 
"nnacl/int8/matmul_int8.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/int8/conv1x1_int8.h" +#include "nnacl_c/base/conv1x1_base.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/matmul_parameter.h" #include "src/common/utils.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc index 95c284341e26349a578cd95c00eda5e578d16619..e408695f3ac9729600b53ce5dd1265dc70d42a60 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/int8/convolution_3x3_int8.h" -#include "nnacl/int8/conv3x3_int8.h" +#include "nnacl_c/int8/conv3x3_int8.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h index 29ee39e4f152fad60cf82eb8b80ad6191e4a1af1..dd9621bdfbfbbc6c3bc212fb9b7ba0193f49f937 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/winograd_transform.h" +#include "nnacl_c/fp32/winograd_transform.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc index 4671a2493f2347a564753f2fd710e687e9f1bd34..66229b091eb47d450d350135f3c17f27fd25ed26 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h" #include "include/errorcode.h" -#include "nnacl/int8/conv_depthwise_int8.h" +#include "nnacl_c/int8/conv_depthwise_int8.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h index 87969621d175d532f1b2a10f483661e4b783225f..12a8f3f9f417845349b629560ecb63b39389cd97 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwise3x3Int8CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc index 86fbe5a183dfa34a0efd37c5db85bad05bcc0a78..473853923dc918da3c0171953d3d65ebd6084ccd 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/convolution_depthwise_int8.h" #include "include/errorcode.h" -#include "nnacl/int8/conv_depthwise_int8.h" +#include 
"nnacl_c/int8/conv_depthwise_int8.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h index b069a4b95bfb1b1f7262c5658b5f05eac522086b..0164c0f240455d3f1b1d6342ff3f41226bf479e5 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc index 00f345c16cf9adc34b4f03b6721a63e973334e0e..d1b772d56031ad8e06943c233d703fbb47ee7f21 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h" #include "include/errorcode.h" -#include "nnacl/int8/conv_depthwise_int8.h" +#include "nnacl_c/int8/conv_depthwise_int8.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h index 21445a36001ff77eba7e7e22145316d26f821bea..61d27ac7cdd64ddf407625968ba1a1c5bfdbee21 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h @@ -21,7 +21,7 @@ #include "src/litert/lite_kernel.h" #include "src/common/log_util.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc index 6118bcfaa7e39cdce2681196788bcf5dd62001f1..2d6b7fe278617848a77bc9e834f397e0fbc913c5 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/convolution_int8.h" #include "include/errorcode.h" -#include "nnacl/int8/conv_int8.h" +#include "nnacl_c/int8/conv_int8.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #ifdef ENABLE_ARM64 diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h index f567a7059d9ac2029af6921b0464a925f6165a87..32d771fdd696db1c13a42882550ca0f1967d7d93 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h @@ -21,7 +21,7 @@ #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" #include "src/common/utils.h" -#include 
"nnacl/int8/conv_int8.h" +#include "nnacl_c/int8/conv_int8.h" namespace mindspore::kernel { class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h index c116ca1f177833a802622ef7f9c08464613230ce..f3f7916b79520d65711e5899d1b9b33a68c5ce4f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_CONVOLUTION_INT8_CREATOR_H_ #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc index 89a89c9bb61f5adffa3073ccaca05e39962bf7f2..8cc7ef1b5b95e524a7841def1c88365e665ef655 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/crop_int8.h" #include "src/litert/kernel_registry.h" -#include "nnacl/base/crop_base.h" +#include "nnacl_c/base/crop_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h index 309af4accdedcde3d922383b7144fb70eabf770e..a5ac9dc2c7a9a490679ea68bc951ad81b944a7a0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h @@ -20,7 +20,7 @@ #include #include #include "include/errorcode.h" -#include "nnacl/int8/crop_int8.h" +#include "nnacl_c/int8/crop_int8.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc index fef6ae2d3e52fe4f83a238d7013793bce5fe8634..eea6e79e98daca8127f733ffc7fe86ff8cccc2c9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h" #include "include/errorcode.h" -#include "nnacl/int8/conv_depthwise_int8.h" +#include "nnacl_c/int8/conv_depthwise_int8.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h index 2add61b4f7ff50705a5a5888d31b8450215566e8..3f5492be89139b35cb4b2d010d7ceafc26ba4238 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class DeconvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h index 
index 26dbb4a6240bd22253adf4a16975caeea6560f2c..55f2bf5ddd8b958db52f367ad8c38b9577f321c4 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h
@@ -21,10 +21,10 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/int8/deconv_int8.h"
-#include "nnacl/int8/common_func_int8.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/deconv_int8.h"
+#include "nnacl_c/int8/common_func_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
 #include "src/litert/kernel/cpu/base/layout_transform.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc
index 378b4e3adadc70daf2af3bc59f3218c677b7ad50..ff893b42db29f36cae439a5c8b682e0a2241adb3 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc
@@ -18,7 +18,7 @@
 #include
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h
index 291bc0711d422a41a82889b80380cf9a0e56a3f7..cf01900013d112088e48a54479f05da5e2a31be0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h
@@ -19,10 +19,10 @@
 #include
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/base/depth_to_space_base.h"
-#include "nnacl/int8/depth_to_space_int8.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/kernel/depth_to_space.h"
+#include "nnacl_c/base/depth_to_space_base.h"
+#include "nnacl_c/int8/depth_to_space_int8.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/kernel/depth_to_space.h"
 namespace mindspore::kernel {
 class DepthToSpaceInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc
index 2dd204fc5a62d41c0c2b9341b01cd25eb1eb00cc..a4c43b57c79dc2011367edf14b7b47f2edda2d31 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc
@@ -18,7 +18,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h
index f4ff2e8f0c0cbd5350b9f4e25e1d6433fdc74070..569d1591f28ee67011d2462e485b83c0592853f1 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
"src/litert/kernel/cpu/base/detection_post_process_base.h" -#include "nnacl/fp32/detection_post_process_fp32.h" +#include "nnacl_c/fp32/detection_post_process_fp32.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc index 044c525b6c50b556920b9c07140a670fde8813e6..e2c8d8686d0d8f8077600b8488b8168447c5d6e6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/cpu/int8/div_int8.h" #include #include -#include "nnacl/int8/arithmetic_int8.h" +#include "nnacl_c/int8/arithmetic_int8.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h index 151278158ff4967392f2ecb06cc98662d4698146..ddebb074071eb32ae335987250b66fab52cbc78a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/div_int8.h" +#include "nnacl_c/int8/div_int8.h" namespace mindspore::kernel { class DivInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc index 3c2d79bf2d2abbb4f5d1601c045682d0e346e8fe..e86c84ecdb541823b3d445271e027583ed513059 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc @@ -15,9 +15,9 @@ */ #include "src/litert/kernel/cpu/int8/dynamic_gather_int8.h" #include -#include "nnacl/gather_parameter.h" -#include "nnacl/int8/dynamic_gather_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/gather_parameter.h" +#include "nnacl_c/int8/dynamic_gather_int8.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h index 3de463253b305c71803217a36edd91c04db57ee7..8fe495fe1ce99755c09dc0f2e6e68799cfbae052 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h @@ -18,8 +18,8 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_DYNAMIC_GATHER_INT8_H_ #include -#include "nnacl/gather_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/gather_parameter.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc index acc43c97528c0e6aa575aa6718837bdd2d3fb93e..daa79ff285df5481c3e05c79b505a08471f15d82 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc @@ -20,10 +20,10 @@ #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" #include "include/errorcode.h" -#include "nnacl/int8/dynamic_quant_int8.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" -#include "nnacl/fp32/transpose_fp32.h" -#include "nnacl/int8/transpose_int8.h" +#include "nnacl_c/int8/dynamic_quant_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" 
+#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/int8/transpose_int8.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h index 023f1fabd365641206ae21307a27ba3382caeda9..137e3d0f0d96378ab5e85be19fc7e2233ff26f06 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h @@ -21,7 +21,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/dynamic_quant_parameter.h" +#include "nnacl_c/dynamic_quant_parameter.h" namespace mindspore::kernel { class DynamicQuantCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc index a9e30b1e098107bf7b8d45d407f528c15325e3c9..40827d175cd37d230aef8f65342740ac8837b7e1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc @@ -21,7 +21,7 @@ #include "schema/model_generated.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" -#include "nnacl/int8/gatherNd_int8.h" +#include "nnacl_c/int8/gatherNd_int8.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h index d9f7f74e61afb3fb43221f920d142edf5dbbc49c..0d16b7dd56b1d9f96f3278c8ff079b7eb28c639a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_GATHERND_INT8_H_ #include -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc index 7f7f815c7ca1a26008ac4f2d47a8f3ceb4570a3a..3096adc98f6e3bc9ac5f8b1ec19019cf9d21d066 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc @@ -16,9 +16,9 @@ #include "src/litert/kernel/cpu/int8/gather_int8.h" #include #include "src/litert/kernel/cpu/int8/dynamic_gather_int8.h" -#include "nnacl/gather_parameter.h" -#include "nnacl/int8/gather_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/gather_parameter.h" +#include "nnacl_c/int8/gather_int8.h" +#include "nnacl_c/int8/quantize.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h index 9b50d763eca4f4e59981ac7c028241f4407d3f41..8f72bcd9c7927b6a76a59c24c8b63d5f5fec49bc 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h @@ -18,8 +18,8 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_GATHER_INT8_H_ #include -#include "nnacl/gather_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/gather_parameter.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h 
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h
index 58d9cf356dad2529848a4d35c7793dc257419a9b..c3d3bd9fcd764b6b520b39175bcbac30cef0e3fd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/litert/kernel/cpu/base/group_convolution_base.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc
index bff0702f47541a9c6ca777771f6221c41c4e6674..6a3c9d1e4fb28e36bdc9c013db789778a214bbf7 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/hswish_int8.h"
 #include
-#include "nnacl/int8/hswish_int8.h"
+#include "nnacl_c/int8/hswish_int8.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h
index b43cd24a5c74e0c7907083beb23e25edf918206e..e5448aecf0ce3037171c1fd3c11024abc1c5ac82 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/hswish_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/hswish_int8.h"
+#include "nnacl_c/int8/quantize.h"
 namespace mindspore::kernel {
 class HswishInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h
index 74c7e6e8f58506c7bd25dc613eddd7a1de945067..4d50c56ebbdd2ac43f0beb3cf9e48196c66ed132 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h
@@ -18,7 +18,7 @@
 #include
 #include "src/litert/kernel/cpu/fp32/l2_norm_fp32.h"
-#include "nnacl/int8/l2_norm_int8.h"
+#include "nnacl_c/int8/l2_norm_int8.h"
 namespace mindspore::kernel {
 class L2NormInt8CPUKernel : public L2NormCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h
index 1c37e1f526755f5d106349f5402296f089b2af56..cebae25a7a9f494e7cb684ba478ff8a8c3f94e0a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h
@@ -18,10 +18,10 @@
 #include
 #include
-#include "nnacl/int8/layer_norm_int8.h"
+#include "nnacl_c/int8/layer_norm_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/kernel/layer_norm.h"
+#include "nnacl_c/kernel/layer_norm.h"
 namespace mindspore::kernel {
 class LayerNormInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h
index 5b65d2d9aeb61ebe3ff8bbb794cfc0ad68112016..1cb72ed7731633834e16851209d70baec4da6cf8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h
@@ -20,8 +20,8 @@
 #include
 #include
 #include "include/errorcode.h"
"nnacl/fp32/activation_fp32.h" -#include "nnacl/int8/leaky_relu_int8.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/int8/leaky_relu_int8.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h index 22402e5e0438803bd9ad543a09a51a8b103f6197..06a5bf3df40066c7d0a4cb2c914d41de13d4470f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h @@ -20,11 +20,11 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/common_func.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/common_func_int8.h" -#include "nnacl/int8/matmul_int8.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/common_func_int8.h" +#include "nnacl_c/int8/matmul_int8.h" namespace mindspore::kernel { class MatmulBaseInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc index bab1f730becbf41c9cb40de04d841232bb329130..ce95c450e7aa7004d2e083bbe26356d625903e16 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h" -#include "nnacl/int8/dynamic_matmul_int8.h" +#include "nnacl_c/int8/dynamic_matmul_int8.h" using mindspore::lite::kCHWDimNumber; using mindspore::lite::kHWDimNumber; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h index 858affc88625ea998d41704c538a013f8e5d72a9..42e0da55368bd1a03df42cb103b3800b7d887cd7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h @@ -21,10 +21,10 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/common_func.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/common_func_int8.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/common_func_int8.h" #include "src/common/common.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc index 64c0d7055c28546be70e9544186f3acd5974fd7e..1c8ba2639f53187b97854693d202b8d880e5a549 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc @@ -16,8 +16,8 @@ #include "src/litert/kernel/cpu/int8/matmul_dynamic_int8.h" #include "src/litert/kernel/cpu/int8/opt_op_handler.h" -#include "nnacl/int8/matmul_int8.h" -#include "nnacl/int8/dynamic_matmul_int8.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/dynamic_matmul_int8.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc 
b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc index 611a02c1702786b8a18d8a01ef97e4b4485c6f80..b2a6ef1bc2dda12030a40545102f2fce1372bc07 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc @@ -16,8 +16,8 @@ #include "src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.h" #include -#include "nnacl/int8/dynamic_matmul_int8.h" -#include "nnacl/int8/matmul_int8.h" +#include "nnacl_c/int8/dynamic_matmul_int8.h" +#include "nnacl_c/int8/matmul_int8.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc index 85428f6f71c0eb6e675e907dec26a91ad0a41af6..fec4a8055b3696742370d7dfcc581a94ad5d5ac8 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/int8/matmul_int8.h" #include "src/litert/kernel/cpu/int8/matmul_dynamic_int8.h" #include "src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.h" -#include "nnacl/int8/matmul_int8.h" -#include "nnacl/common_func.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/common_func.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h index 0eb6ca65b1225b9a3b3046c3f77d70ffb24fb736..55bc989e670001e0e89bfcdd4e57b25af989f564 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h @@ -18,8 +18,8 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_INT8_H_ #include -#include "nnacl/matmul_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/int8/matmul_base_int8.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h index 6464ea7cc08b8d697a7a3317f8428727dee8f915..0562a7277f456c765ce072d26124bc42268805bb 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h @@ -20,9 +20,9 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/mul_parameter.h" -#include "nnacl/int8/mul_int8.h" -#include "nnacl/int8/arithmetic_int8.h" +#include "nnacl_c/mul_parameter.h" +#include "nnacl_c/int8/mul_int8.h" +#include "nnacl_c/int8/arithmetic_int8.h" namespace mindspore::kernel { class MulInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc index 5718c898ac4ac6a4baa53f9438fd90ccf3dad3a6..229d4506e8ccff12ef2402ca1e71928edb84db8d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/opt_op_handler.h" #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h index 
91691b1b0da3be143dbf120803290ad9d9f01f0b..128147bf9bcfba101500b8b38f3b91bb53a5d62a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h @@ -18,7 +18,7 @@ #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc index 56fcec7b963ff8f3d1fd1ed8b98ee2876c74938f..27c1f8f50f79d7a0f63fdb4bec945b226b094bd8 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc @@ -18,7 +18,7 @@ #include #include #include "src/litert/kernel_registry.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h index 15774096cc25d72f667c124e5f310aa8e93ae86e..418d3deddcc8ff81cd892197c3090bc787c1b1f6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h @@ -20,10 +20,10 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/errorcode.h" -#include "nnacl/pad_parameter.h" -#include "nnacl/int8/pad_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/pad_parameter.h" +#include "nnacl_c/int8/pad_int8.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::kernel { class PadInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc index 662871fd40d46732573d4c64957488888747707e..560a856197b3c5a38bd54fbb947ff24073b03e6a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/pooling_int8.h" #include -#include "nnacl/int8/pooling_int8.h" +#include "nnacl_c/int8/pooling_int8.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h index ed8db60ef747d82b01e3d88849595ab262e58572..4c2f3dfaba35722e850d31fcd8219d35c5e74c2f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/pooling_int8.h" +#include "nnacl_c/int8/pooling_int8.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc index 4330b32b0d5cc0aeeecdb96c34ed962eb11487b4..d7cffb4562b27aaa689e49b51a9f4db95158bacf 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/power_int8.h" #include -#include "nnacl/int8/power_int8.h" +#include "nnacl_c/int8/power_int8.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h index 
5f2bfa6c2d28910ae0107b42cf40310325014639..e5e4532af67348c00289862ece2ac1fc3eee7a62 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/pow_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/pow_parameter.h" namespace mindspore::kernel { class PowerInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc index 008bbe30c0d4585d85155bb454880ef453273ed9..0ba3d0f804e32cf97b225a55aa012e68e5721b62 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc @@ -18,8 +18,8 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/pack.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h index 149d209002fdea1507a5b525f5558ab27ce7f1cd..a1dc31900d642f8555e3ff82391c9d46bdfcf107 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h @@ -19,9 +19,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/reduce_parameter.h" -#include "nnacl/int8/reduce_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/reduce_parameter.h" +#include "nnacl_c/int8/reduce_int8.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/kernel/cpu/base/reduce_base.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h index 504d964a204a388bdfb89f0c85ea8ca71128ef08..bb9abb5378e24bb5e44b13935fee4091dbfb00f0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/int8/relux_int8.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/int8/relux_int8.h" namespace mindspore::kernel { constexpr size_t kRelu6Min = 0; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc index a91c5bd3a78279ac3d118a8de0ae6a494ae71036..1df86f3c0d1b54b11d7dfe1606d62c008600a117 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/reshape_int8.h" #include -#include "nnacl/int8/reshape_int8.h" +#include "nnacl_c/int8/reshape_int8.h" #include "schema/model_generated.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h index dff07330b6f941a050da99faca6982ca3d540df9..91f6e9bad3e8bbe40fdd446dc7ee9e01bfd1d64e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/reshape_parameter.h" +#include 
"nnacl_c/reshape_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc index 53e39d3024dc0396f4a0df8900d3dce0e4443639..aac5c21717b01788bc87440c63097eaf876895c6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc @@ -18,7 +18,7 @@ #include #include #include "include/errorcode.h" -#include "nnacl/int8/resize_int8.h" +#include "nnacl_c/int8/resize_int8.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h index 69d76b809bd667d947570fa551e5c13f87bbba9d..3e38c803412e5523aed7973e56aca1c18924ab24 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/resize_base.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" using mindspore::schema::PrimitiveType_Resize; using mindspore::schema::ResizeMethod; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h index e4cdc07ee4fbfa7518e0436a3188e5d33d01b5f2..9c569e9ef87ee7c22f57fbf63128661ad1a58267 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h @@ -20,10 +20,10 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/arithmetic_int8.h" -#include "nnacl/int8/scale_int8.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/arithmetic_int8.h" +#include "nnacl_c/int8/scale_int8.h" namespace mindspore::kernel { class ScaleInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc index 00a3212a2098cc43ca20f28a4e752238b73543f6..ca4651cee4aeea0825527f90ce6c1664389a6468 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/int8/sigmoid_int8.h" #include #include -#include "nnacl/int8/sigmoid_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/sigmoid_int8.h" +#include "nnacl_c/int8/quantize.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h index 1f383ae6f3938e5a77775e82d105a52a711f48fe..68fcd2cc808e5d2e61ddc5dbe628bb7e9a028906 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/sigmoid_int8.h" +#include "nnacl_c/int8/sigmoid_int8.h" namespace mindspore::kernel { class SigmoidInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc index 9708de00edc4f68a4061763f104049be6b6cb305..ab71d4b5105a7d02737cd45ceb806c6c6bf68dc7 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc @@ -17,9 +17,9 @@ #include "src/litert/kernel/cpu/int8/slice_int8.h" #include #include "src/litert/kernel_registry.h" -#include "nnacl/int8/slice_int8.h" +#include "nnacl_c/int8/slice_int8.h" #include "include/errorcode.h" -#include "nnacl/base/slice_base.h" +#include "nnacl_c/base/slice_base.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h index 2b2745ce88a0bd03e3f3647f039bce8e1e1fe2dc..3ebc7536ef68f761add1e301b726216d7d4ddeba 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h @@ -19,9 +19,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/kernel/slice.h" -#include "nnacl/slice_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/slice_parameter.h" namespace mindspore::kernel { class SliceInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc index 0d418eb9530c0330f3dde791fe8b167640ab79c5..34ea1f6f7d81f0cee11c22bd9b07ee8297a11187 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/cpu/int8/softmax_int8.h" #include -#include "nnacl/int8/softmax_int8.h" +#include "nnacl_c/int8/softmax_int8.h" #include "schema/model_generated.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h index 61395c6753e65a43472087c4ccf9d947207aaba5..f6c0d0ace980c1e21ea224f688eefaa61534cce4 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::kernel { class SoftmaxInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc index a478edc6b578132befddfbcb6b02b3650dacbd02..0cb03a4401d4ab75ae204232becb8f3696384a8d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/int8/space_to_batch_int8.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/space_to_batch_fp32.h" -#include "nnacl/int8/space_to_batch_int8.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" +#include "nnacl_c/int8/space_to_batch_int8.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc index a24bdd3ad2ce3387378197dcaab8dce804c2fd3d..87f8ae0a3fdd2c49e31a84e0e82247b51f48127c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc +++ 
b/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc @@ -16,8 +16,8 @@ #include "src/litert/kernel/cpu/int8/split_int8.h" #include -#include "nnacl/split_parameter.h" -#include "nnacl/int8/split_int8.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/int8/split_int8.h" #include "include/errorcode.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h index 0b0660382a9082aa0219d5705d82ca715fa08c67..28380f7874db4a6a8fd9eae2fd8f6c6441554df7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h @@ -20,8 +20,8 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/int8/squeeze_int8.h" -#include "nnacl/squeeze_parameter.h" +#include "nnacl_c/int8/squeeze_int8.h" +#include "nnacl_c/squeeze_parameter.h" using mindspore::lite::InnerContext; namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h index af3ab66f742476cc85f1668696712b132453d028..77ecd6c8e2835d8437acf8597919be6435442c4f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h @@ -19,10 +19,10 @@ #include #include #include -#include "nnacl/int8/arithmetic_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/arithmetic_int8.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/lite_kernel.h" -#include "nnacl/int8/sub_int8.h" +#include "nnacl_c/int8/sub_int8.h" namespace mindspore::kernel { class SubInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h index d202a6af67e7bdbcb04d261c49d87d7069ebd3e5..f43454a049f04ca162b6ac77f9dc8833e3e1cb30 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h @@ -21,8 +21,8 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/tanh_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/tanh_int8.h" +#include "nnacl_c/int8/quantize.h" #include "include/errorcode.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h index 3f319f9bd8fd19544a8a298d714057c8b533aa65..2250ef636b3da146ac968c3d8918cf3e25b90b4d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/topk_int8.h" +#include "nnacl_c/int8/topk_int8.h" namespace mindspore::kernel { class TopKInt8CPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc index 40d52b5e028d21670b9b0f9ed3a933f1805f13a5..6a92baab251f6cebcd92551ec7f5a7796c260de2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc @@ -16,8 +16,8 @@ #include "src/litert/kernel/cpu/int8/transpose_int8.h" #include "src/litert/kernel_registry.h" -#include "nnacl/int8/transpose_int8.h" -#include "nnacl/int8/pack_int8.h" +#include "nnacl_c/int8/transpose_int8.h" +#include "nnacl_c/int8/pack_int8.h" using 
mindspore::lite::KernelRegistrar; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc index e539c2ba801e32192b1dc5624c447fcb730120b7..e490ab0ba4abc7f5ba1a1e7141253894c682698d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc +++ b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "nnacl/int8/unsqueeze_int8.h" +#include "nnacl_c/int8/unsqueeze_int8.h" #include "src/litert/kernel/cpu/int8/unsqueeze_int8.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h index daa3e5e431001336e4a8ebb0d99d10db2781ae7c..52996b428f0309761628b749f0cb467d44cd3d49 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/unsqueeze_int8.h" +#include "nnacl_c/int8/unsqueeze_int8.h" #include "src/litert/kernel/cpu/base/layout_transform.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc index 0ae492b06a269377c64bff4e68c74b6460bd05a3..e1f7671efc0eed8de4ce05f4912e150ac15ff080 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc @@ -17,7 +17,7 @@ #include "nnacl/nnacl_batchnorm.h" #include "nnacl/nnacl_manager.h" #include "include/errorcode.h" -#include "nnacl/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" using mindspore::schema::PrimitiveType_BatchNorm; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc index 86eade3c291a5a52857f08e96700d741ea162d4f..17f7fb978bb20e8e2e19d9147d00970b040a1766 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc @@ -18,8 +18,8 @@ #include "nnacl/cxx_utils.h" #include "src/litert/pack_weight_manager.h" #include "nnacl/nnacl_manager.h" -#include "nnacl/kernel/convolution_base.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/conv_parameter.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc index e604423ba5a959f23d668cfae149b301d1a70120..2d9f635e10e221cff3f676f2bba8b9cbe8162998 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc @@ -17,8 +17,8 @@ #include "nnacl/nnacl_fused_batch_norm.h" #include "nnacl/nnacl_manager.h" #include "include/errorcode.h" -#include "nnacl/fp32/batchnorm_fp32.h" -#include "nnacl/kernel/fused_batch_norm.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include "nnacl_c/kernel/fused_batch_norm.h" using mindspore::schema::PrimitiveType_FusedBatchNorm; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc index 
e970b97ce727a048204bcc38cffe4724563dd0b0..e467a35130e90b17c23b7029d40087c955c4241b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc @@ -18,7 +18,7 @@ #include "nnacl/cxx_utils.h" #include "src/tensor.h" #include "include/errorcode.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h index c03a22aa949ba9fa16c7cf1dbc50d525c39034e0..232036860e35471ce533a4338ac74937aaf74c1b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_NNACL_KERNEL_H_ #include -#include "nnacl/kernel.h" +#include "nnacl_c/kernel.h" #include "src/executor/kernel_exec.h" #include "src/litert/lite_kernel.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc index 187c95acbe18118f65ce0d119569b474dded0959..7f169ec149391a2891beaae32e8c1da9f90a6082 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc @@ -17,7 +17,7 @@ #include "nnacl/nnacl_matmul.h" #include "nnacl/nnacl_manager.h" #include "include/errorcode.h" -#include "nnacl/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_base.h" #include "nnacl/cxx_utils.h" #include "src/litert/pack_weight_manager.h" #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h index acd4d31b194a2d78b51fba282812bc04e3a6cdde..7e10fc2078efaaecd1247217176617020d2c54d7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h @@ -19,7 +19,7 @@ #include #include "nnacl/nnacl_kernel.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::nnacl { class MatmulKernel : public NNACLKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc index 03c9efd226c5e6a8d7d0a19407d8a9ad6d12d37a..a1d2f3d308787ad950908ff2a61e85e17b3d05e6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/kernel/reduce.h" +#include "nnacl_c/kernel/reduce.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc index 8370f83ec723939567743b1ccb951267af6ae8c4..fb86fdc4ff2304e7f0673fceb67c090db6292a20 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/kernel/strided_slice.h" +#include "nnacl_c/kernel/strided_slice.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff 
--git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..374b8bb8817b3b9aeea159b0b5d4c783a2180298 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt @@ -0,0 +1,293 @@ +project(nnacl) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${NNACL_DIR}/..) + +set(NNACL_SIMD_DIR ${CMAKE_BINARY_DIR}/src/nnacl_c) +include_directories(${NNACL_SIMD_DIR}/..) +file(GLOB SIMD_CONFIG_HEADER + ${NNACL_DIR}/base/*_simd.h.in + ${NNACL_DIR}/fp32/*_simd.h.in + ${NNACL_DIR}/fp32/online_fusion/*_simd.h.in + ${NNACL_DIR}/fp32_grad/*_simd.h.in +) +function(generate_simd_header_code) + foreach(simd_config_file ${SIMD_CONFIG_HEADER}) + string(REGEX MATCHALL "[0-9A-Za-z_]*_simd.h.in" tmp1 ${simd_config_file}) + string(REGEX REPLACE "_simd.h.in" "_${SIMD_INSTRUCTION_LOWER}.h" tmp2 ${tmp1}) + string(REGEX REPLACE "_simd.h.in" "" tmp3 ${tmp1}) + string(TOLOWER ${tmp3} OP_NAME_LOWER) + string(TOUPPER ${tmp3} OP_NAME_UPPER) + configure_file(${NNACL_DIR}/op_simd_header_file.h.in ${NNACL_SIMD_DIR}/${tmp3}_simd.h) + endforeach() +endfunction() + +function(generate_simd_code SIMD_INSTRUCTION SIMD_BLOCK_NUM SIMD_TARGET) + string(TOLOWER ${SIMD_INSTRUCTION} SIMD_INSTRUCTION_LOWER) + set(SIMD_DEFINE "#define MS_SIMD_${SIMD_INSTRUCTION}") + set(SIMD_UNDEF "#undef MS_SIMD_${SIMD_INSTRUCTION}") + set(SIMD_DEF_INSTRUCTION "#define MS_SIMD_INSTRUCTION MS_SIMD_${SIMD_INSTRUCTION}_INSTRUCTION") + set(SIMD_UNDEF_INSTRUCTION "#undef MS_SIMD_INSTRUCTION") + set(SIMD_DEF_BLOCK_NUM "#define BLOCK_NUM ${SIMD_BLOCK_NUM}") + set(SIMD_UNDEF_BLOCK_NUM "#undef BLOCK_NUM") + if(SIMD_INSTRUCTION_LOWER STREQUAL "neon") + set(SIMD_TARGET_BEGIN "") + set(SIMD_TARGET_END "") + else() + set(SIMD_TARGET_BEGIN "#pragma GCC push_options\n#pragma GCC target(${SIMD_TARGET})") + set(SIMD_TARGET_END "#pragma GCC pop_options") + endif() + + set(SIMD_INSTRUCTION_BEGIN "${SIMD_TARGET_BEGIN}\n${SIMD_DEF_INSTRUCTION}\n${SIMD_DEF_BLOCK_NUM}\n${SIMD_DEFINE}") + set(SIMD_INSTRUCTION_END "${SIMD_UNDEF_INSTRUCTION}\n${SIMD_UNDEF_BLOCK_NUM}\n${SIMD_TARGET_END}\n${SIMD_UNDEF}") + foreach(simd_config_file ${SIMD_CONFIG_HEADER}) + string(REGEX MATCHALL "[0-9A-Za-z_]*_simd.h.in" tmp1 ${simd_config_file}) + string(REGEX REPLACE "_simd.h.in" "_${SIMD_INSTRUCTION_LOWER}.h" tmp2 ${tmp1}) + configure_file(${simd_config_file} ${NNACL_SIMD_DIR}/${SIMD_INSTRUCTION_LOWER}/${tmp2}) + endforeach() +endfunction() +generate_simd_code(NEON 4 \"\") +generate_simd_code(SSE 4 \"sse4.1\") +generate_simd_code(AVX 8 "\"avx\", \"avx2\"") +generate_simd_code(AVX512 16 \"avx512f\") +generate_simd_header_code() + +if(ENABLE_CPU AND NOT MSVC) + set(CMAKE_C_FLAGS "-Wno-attributes ${CMAKE_C_FLAGS}") +endif() + +if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64 OR PLATFORM_MCU) + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND DEFINED ARCHS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math -Wno-shorten-64-to-32") + endif() + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND NOT DEFINED ARCHS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math") + endif() + if(TARGET_OHOS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-inline-asm") + endif() +elseif(NOT MSVC) + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing 
-ffunction-sections \ + -fdata-sections") + endif() +endif() + +if(NOT MSVC) + if("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") + endif() + if("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma") + endif() + if(WIN32) + if("${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fno-asynchronous-unwind-tables") + endif() + endif() +endif() + + +########################### files ########################### +file(GLOB KERNEL_SRC + ${NNACL_DIR}/*.c + ${NNACL_DIR}/fp32/*.c + ${NNACL_DIR}/infer/*.c + ${NNACL_DIR}/base/*.c + ${NNACL_DIR}/fp32_grad/*.c + ${NNACL_DIR}/kernel/*.c + ${NNACL_DIR}/experimental/*.c + ${NNACL_DIR}/fp32/online_fusion/*.c +) + +set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c + ${NNACL_DIR}/fp32/matmul_avx512_mask_fp32.c + ${NNACL_DIR}/fp32/conv_im2col_avx512_fp32.c +) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX512_FILE}) + +set(KERNEL_AVX_FILE ${NNACL_DIR}/fp32/conv_sw_avx_fp32.c + ${NNACL_DIR}/fp32/conv_1x1_avx_fp32.c + ${NNACL_DIR}/fp32/matmul_avx_fp32.c + ${NNACL_DIR}/fp32/conv_depthwise_avx_fp32.c) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX_FILE}) + +set(KERNEL_ARM64_FILE ${NNACL_DIR}/fp32/conv_sw_arm64_fp32.c) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_ARM64_FILE}) + +if(NOT MSLITE_ENABLE_RUNTIME_PASS) + list(REMOVE_ITEM KERNEL_SRC ${NNACL_DIR}/infer/shape_fusion_infer.c) +endif() +if((NOT DEFINED MSLITE_ENABLE_INT8) OR MSLITE_ENABLE_INT8) + file(GLOB KERNEL_SRC_INT8 + ${NNACL_DIR}/int8/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INT8} + ) +else() + set(KERNEL_SRC + ${KERNEL_SRC} + ${NNACL_DIR}/int8/pack_int8.c + ${NNACL_DIR}/int8/quantize.c + ) +endif() + +if(MSLITE_ENABLE_SPARSE_COMPUTE) + file(GLOB KERNEL_SRC_SPARSE + ${NNACL_DIR}/fp32_sparse/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_SPARSE} + ) +endif() + +if(MSLITE_ENABLE_STRING_KERNEL) + file(GLOB KERNEL_SRC_INFER_STRING + ${NNACL_DIR}/infer/string/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INFER_STRING} + ) +endif() +if(MSLITE_ENABLE_CONTROLFLOW) + file(GLOB KERNEL_SRC_INFER_CONTROL_TENSORLIST + ${NNACL_DIR}/infer/control/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INFER_CONTROL_TENSORLIST} + ) +endif() +if(PLATFORM_ARM64) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm64/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${KERNEL_ARM64_FILE}) +endif() + +if(PLATFORM_ARM32) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm32/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) +endif() + +if("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + file(GLOB ASSEMBLY_SSE_SRC ${NNACL_DIR}/intrinsics/sse/*.c) + set_property(SOURCE ${ASSEMBLY_SSE_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_SSE_SRC + ${ASSEMBLY_SSE_SRC} + ${KERNEL_SSE_FILE}) + set_source_files_properties(${MS_X86_SSE_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_SSE_SRC}) +endif() + +if("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + file(GLOB ASSEMBLY_AVX_SRC + ${NNACL_DIR}/intrinsics/avx/*.c + ${NNACL_DIR}/assembly/avx/*.S + ${NNACL_DIR}/intrinsics/ms_simd_cpu_info.c) + set_property(SOURCE ${ASSEMBLY_AVX_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_AVX_SRC + 
${ASSEMBLY_AVX_SRC} + ${KERNEL_AVX_FILE}) + set_source_files_properties(${MS_X86_AVX_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_AVX_SRC}) +endif() + +if("${X86_64_SIMD}" STREQUAL "avx512") + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + file(GLOB HPC_SRC ${NNACL_DIR}/experimental/HPC-generator/gemm_avx512/*.c + ${NNACL_DIR}/experimental/HPC-generator/gemm_mask_avx512/*.c) + + set_property(SOURCE ${HPC_SRC} PROPERTY LANGUAGE C) + endif() + + file(GLOB ASSEMBLY_AVX512_SRC ${NNACL_DIR}/assembly/avx512/*.S) + set_property(SOURCE ${ASSEMBLY_AVX512_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_AVX512_SRC + ${HPC_SRC} + ${ASSEMBLY_AVX512_SRC} + ${KERNEL_AVX512_FILE}) + + set_source_files_properties(${MS_X86_AVX512_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_AVX512_SRC}) +endif() + +if(APPLE) + set_source_files_properties(${ASSEMBLY_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") +endif() + +########################### build nnacl library ######################## +if(NOT MSVC) +string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +endif() + +if(PLATFORM_ARM) + set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c) + set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math") +endif() + +add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_DEBUG) +endif() + +if(ENABLE_CPU) + if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_ARM ENABLE_ARM64 ENABLE_NEON) + target_compile_options(nnacl_mid PRIVATE -ffast-math -flax-vector-conversions) + elseif("${X86_64_SIMD}" STREQUAL "sse") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE) + elseif("${X86_64_SIMD}" STREQUAL "avx") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX) + elseif("${X86_64_SIMD}" STREQUAL "avx512") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512) + endif() + if(NOT MSVC) + target_compile_options(nnacl_mid PRIVATE -fPIC -fstack-protector-all) + add_library(nnacl SHARED $<TARGET_OBJECTS:nnacl_mid>) + else() + add_library(nnacl STATIC $<TARGET_OBJECTS:nnacl_mid>) + endif() + if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") + if(NOT CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(nnacl PRIVATE -Wl,-z,relro,-z,now,-z,noexecstack) + target_link_libraries(nnacl PRIVATE m) + endif() + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + target_link_options(nnacl PRIVATE -s) + endif() + endif() +endif() + +set(nnacl_static_obj $<TARGET_OBJECTS:nnacl_mid>) +########################### arm fp16 build optimize library ######################## +if(PLATFORM_ARM) + add_subdirectory(${NNACL_DIR}/optimize) + if(TARGET nnacl_optimize_mid) + set(nnacl_static_obj ${nnacl_static_obj} $<TARGET_OBJECTS:nnacl_optimize_mid>) + endif() + if(TARGET nnacl_fp16_mid) + set(nnacl_static_obj ${nnacl_static_obj} $<TARGET_OBJECTS:nnacl_fp16_mid>) + endif() +endif() +if(NOT ${CMAKE_GENERATOR} MATCHES "Ninja") + add_library(nnacl_static STATIC ${nnacl_static_obj}) + set_target_properties(nnacl_static PROPERTIES OUTPUT_NAME "nnacl") + set_target_properties(nnacl_static PROPERTIES CLEAN_DIRECT_OUTPUT 1) +endif() diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS new file mode 100644 index
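
A note on the generate_simd_code() machinery above: each *_simd.h.in template is instantiated once per instruction set (NEON/SSE/AVX/AVX512 with BLOCK_NUM 4/4/8/16), and configure_file() wraps the instantiated body in the matching #pragma GCC target scope so one translation unit can carry every variant. A minimal sketch of what one AVX instantiation effectively compiles down to, assuming an elementwise-add template; ElementAddAVX and its body are illustrative stand-ins, not a file from this patch:

#include <immintrin.h>
#include <stdint.h>

#pragma GCC push_options
#pragma GCC target("avx", "avx2") /* what SIMD_TARGET_BEGIN expands to for AVX */
static int64_t ElementAddAVX(int64_t index, const float *in0, const float *in1,
                             float *out, int64_t size) {
  /* BLOCK_NUM == 8 for AVX: one 256-bit register holds 8 floats */
  for (; index <= size - 8; index += 8) {
    __m256 v0 = _mm256_loadu_ps(in0 + index);
    __m256 v1 = _mm256_loadu_ps(in1 + index);
    _mm256_storeu_ps(out + index, _mm256_add_ps(v0, v1));
  }
  return index; /* caller finishes the scalar tail */
}
#pragma GCC pop_options
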
0000000000000000000000000000000000000000..350278889db189cb4c82132710e9f9e54df85d5c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS @@ -0,0 +1,11 @@ +approvers: +- jjfeing +- YeFeng_24 +- fatmouse007fatmouse007 +- xu_anyue + +reviewers: +- liuf9 + +options: + no_parent_owners: true diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c756a7db075467cf782af013fd36cc0c8ed942e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md @@ -0,0 +1 @@ +NNACL (Neural Network Accelerated Computing Library) is a high-performance library of neural network inference computing kernels for ARM. diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..7a9af3249c6321f7450ab10b32728867fdb10ec2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ACTIVATION_PARAMETER_H_ +#define NNACL_ACTIVATION_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct ActivationParameter { + OpParameter op_parameter_; + int type_; + float alpha_; + float min_val_; + float max_val_; + bool approximate_; +} ActivationParameter; + +#endif // NNACL_ACTIVATION_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d4d30934c090356862c0c763e28108e52bc313b0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
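
The ActivationParameter struct just added is typical of the nnacl_c parameter headers: an embedded OpParameter followed by op-specific fields. A minimal sketch of how a caller might fill it for a ReLU6-style clamp; the helper name is ours, and the type_ value is a placeholder for the runtime's schema enum:

#include <stdbool.h>
#include "nnacl_c/activation_parameter.h"

static void FillRelu6Param(ActivationParameter *param) {
  param->type_ = 0;            /* placeholder: real code uses a schema ActivationType value */
  param->alpha_ = 0.0f;        /* slope, only meaningful for leaky variants */
  param->min_val_ = 0.0f;      /* output clamp range: ReLU6 keeps [0, 6] */
  param->max_val_ = 6.0f;
  param->approximate_ = false; /* approximation switch (an assumption based on the field name) */
}
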
+ */ + +#ifndef NNACL_AFFINE_PARAMETER_H_ +#define NNACL_AFFINE_PARAMETER_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" +typedef struct AffineParameter { + OpParameter op_parameter_; + // parameters from splice op + int context_size_; + int *context_; + int output_dim_; + // parameters from activation op + int activation_type_; + // parameters from matmul op + MatMulParameter *matmul_parameter_; +} AffineParameter; +#endif // NNACL_AFFINE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..3ddbde86b1140b553c01e71afa5eb5ef121ab9fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ALL_GATHER_PARAMETER_H_ +#define NNACL_ALL_GATHER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct AllGatherParameter { + // primitive parameter + OpParameter op_parameter_; + char group_[DEFAULT_GROUP_NAME_LEN]; + + // other parameter + int rank_size_; +} AllGatherParameter; +#endif // NNACL_ALL_GATHER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0fdad1789873878d460c4aa3587392fa5a5d1d1b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_ARG_MIN_MAX_PARAMETER_H_ +#define NNACL_ARG_MIN_MAX_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ArgMinMaxParameter { + OpParameter op_parameter_; + int32_t axis_; + int32_t topk_; + bool keep_dims_; + bool out_value_; +} ArgMinMaxParameter; + +#endif // NNACL_ARG_MIN_MAX_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..adc4643a49317cc6dc0e279604ba9bcb5deb9715 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ARITHMETIC_PARAMETER_H_ +#define NNACL_ARITHMETIC_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/nnacl_utils.h" + +#define ARITHMETIC_SUPPORT_DIMS_NUM 10 + +typedef struct ArithmeticParameter { + OpParameter op_parameter_; + bool broadcasting_; + size_t ndim_; + int activation_type_; + int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int64_t in_elements_num0_; + int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int64_t in_elements_num1_; + + int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_elements_num_; + + int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM]; + + int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int eltwise_mode_; // needed by eltwise +} ArithmeticParameter; + +#endif // NNACL_ARITHMETIC_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..611a306412210da3c1e8d9dff6a0a1d1568689dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ARITHMETIC_SELF_PARAMETER_H_ +#define NNACL_ARITHMETIC_SELF_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/quantize.h" + +// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops.
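
The shape/stride arrays in ArithmeticParameter above are plain row-major strides over the broadcast-aligned shapes (up to ARITHMETIC_SUPPORT_DIMS_NUM dims). A standalone sketch of the usual derivation; nnacl's common functions provide equivalents, so this version is only for illustration:

#include <stddef.h>

/* Row-major strides: strides[i] is the product of all dims to the right of i. */
static void ComputeStridesRef(const int *shape, int *strides, size_t ndim) {
  int stride = 1;
  for (size_t i = ndim; i > 0; --i) {
    strides[i - 1] = stride;
    stride *= shape[i - 1];
  }
}
/* e.g. shape {2, 3, 4} yields strides {12, 4, 1} */
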
+typedef struct ArithmeticSelfParameter { + OpParameter op_parameter_; + ArithSelfQuantArg quant_arg_; +} ArithmeticSelfParameter; + +#endif // NNACL_ARITHMETIC_SELF_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S new file mode 100644 index 0000000000000000000000000000000000000000..eadcf972c9587211495e3ae1e9f7ba9c17f7c120 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S @@ -0,0 +1,128 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, +// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max +// size_t per_channel) { + +// todo: support per channel +// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step, +// r8: channel, r9: in_zp, r10: out_zp, r11: out_multiplier, r12: left_shift, r13: right_shift +// r14: acc_min, r15: acc_max +asm_function ConvDw3x3Int8BorderPixel + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + + push {r4-r8, r9-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + ldrb r10, [sp, #20] // in_zp + vdup.8 d18, r10 + ldr r10, [sp, #24] // out_zp + vdup.32 q15, r10 + ldr r10, [sp, #28] // out_multiplier + vdup.32 q14, r10 + ldr r10, [sp, #32] // left_shift + vdup.32 q13, r10 + ldr r10, [sp, #36] // right_shift + vdup.32 q12, r10 + ldr r10, [sp, #40] // acc_min + vdup.32 q11, r10 + ldr r10, [sp, #44] // acc_max + vdup.32 q10, r10 + + mov r4, #2 + mul lr, r8, r4 + + LoopC: + mov r9, r1 + mov r10, r2 + ldr r4, [sp] + + vld1.32 {q3}, [r3]! + vld1.32 {q4}, [r3]! 
+ LoopH: + mov r11, r9 + mov r12, r10 + ldr r5, [sp, #4] + LoopW: + vld1.8 {d0}, [r11], r7 + vld1.16 {d2, d3}, [r12], lr // weight + vsubl.s8 q2, d0, d18 // -zp + + vmlal.s16 q3, d4, d2 + vmlal.s16 q4, d5, d3 + + subs r5, r5, #1 + bne LoopW + subs r4, r4, #1 + add r9, r9, r6 + mov r11, #3 + mul r5, lr, r11 + add r10, r10, r5 + bne LoopH + + vshl.s32 q3, q3, q13 + vqrdmulh.s32 q3, q3, q14 + vand q5, q3, q12 + vshr.s32 q5, q5, #31 + vqadd.s32 q3, q3, q5 + vrshl.s32 q3, q3, q12 + vadd.i32 q3, q3, q15 + vmax.s32 q3, q3, q11 + vmin.s32 q3, q3, q10 + vqmovn.s32 d14, q3 + + vshl.s32 q4, q4, q13 + vqrdmulh.s32 q4, q4, q14 + vand q6, q4, q12 + vshr.s32 q6, q6, #31 + vqadd.s32 q4, q4, q6 + vrshl.s32 q4, q4, q12 + vadd.i32 q4, q4, q15 + vmax.s32 q4, q4, q11 + vmin.s32 q4, q4, q10 + vqmovn.s32 d15, q4 + vqmovn.s16 d16, q7 + + vst1.8 {d16}, [r0]! + add r1, r1, #8 + add r2, r2, #16 + + sub r8, r8, #8 + cmp r8, #8 + bge LoopC + + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r8, r9-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S new file mode 100644 index 0000000000000000000000000000000000000000..5da6cdd47797dfc614981b2fba64792868e24741 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6) +// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step, +// r8: kernel_w, r9: relu, r10: relu6 +asm_function ConvDwFp32Border + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r4, [sp] // height + ldr r5, [sp, #4] // width + ldr r6, [sp, #8] // in_kh_step + ldr r7, [sp, #12] // in_kw_step + ldr r8, [sp, #16] // kernel_w + ldr r9, [sp, #20] // relu + ldr r10, [sp, #24] // relu6 + + vld1.32 {q0}, [r3] // bias + vmov.i32 q1, #6 // relu6 + vcvt.f32.s32 q1, q1 + veor q2, q2, q2 // relu + + LoopH: + mov r11, r1 + mov r12, r2 + mov r14, r5 + LoopW: + vld1.32 {q3}, [r11], r7 + vld1.32 {q4}, [r12]! 
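
The vector tail of ConvDw3x3Int8BorderPixel above (vshl, vqrdmulh, the vand/vshr/vqadd/vrshl rounding fixup, zero-point add, clamp, narrow) is the standard fixed-point requantization sequence. A scalar model of one lane, with the INT32_MIN saturation corner case elided; the helper name and C formulation are ours, only the operation order comes from the assembly:

#include <stdint.h>

static int8_t RequantizeRef(int32_t acc, int left_shift, int right_shift,
                            int32_t multiplier, int32_t out_zp,
                            int32_t acc_min, int32_t acc_max) {
  int32_t x = (int32_t)((int64_t)acc << left_shift);         /* vshl.s32: pre-scale */
  int64_t ab = (int64_t)x * multiplier;                      /* vqrdmulh.s32: rounding */
  int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30)); /* doubling high multiply */
  int32_t y = (int32_t)((ab + nudge) >> 31);
  if (right_shift > 0) {                                     /* rounding right shift */
    int32_t mask = (int32_t)((1LL << right_shift) - 1);
    int32_t remainder = y & mask;
    int32_t threshold = (mask >> 1) + (y < 0 ? 1 : 0);
    y = (y >> right_shift) + (remainder > threshold ? 1 : 0);
  }
  y += out_zp;                   /* vadd.i32: output zero point */
  y = y < acc_min ? acc_min : y; /* vmax.s32 */
  y = y > acc_max ? acc_max : y; /* vmin.s32 */
  return (int8_t)y;              /* vqmovn narrowing */
}
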
+ vmla.f32 q0, q3, q4 + subs r14, r14, #1 + bne LoopW + subs r4, r4, #1 + add r1, r1, r6 + add r2, r2, r8 + bne LoopH + + cmp r10, #0 + bne Relu6 + cmp r9, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q1 + Relu: + vmax.f32 q0, q0, q2 + Write: + vst1.32 {q0}, [r0] + + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..9935418bc218f902221d083ffe7e63f449ac819c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S @@ -0,0 +1,176 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// r0: dst, r1: src, r2: weight, r3: bias, #0: height, #4: width, #8: kernel_h, #12: kernel_w, +// #16: out_h_step, #20: block_channel, #24: in_sh_step, #28: in_sw_step, #32: in_kh_step,#36: in_kw_step +// #40: relu, #44: relu6 +asm_function ConvDwFp32Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r4, [sp] // height + + vld1.32 {q13}, [r3] + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + LoopH: + ldr r1, [sp, #-44] // src_w, src_h = src + ldr r5, [sp, #4] // width + ldr r0, [sp, #-48] // dst_w, dst_h = dst + cmp r5, #4 + blt LoopW + LoopW4: + ldr r11, [sp, #28] // in_sw_step + mov r8, r1 // src_kh, src_w + ldr r2, [sp, #-40] // weight_kh, weight + ldr r6, [sp, #8] // kernel_h + vmov q0, q13 + vmov q1, q13 + vmov q2, q13 + vmov q3, q13 + LoopKh4: + ldr r7, [sp, #12] // kernel_w + mov lr, r8 // src_kw, src_kh + LoopKw4: + ldr r12, [sp, #36] //in_kw_step + mov r10, lr + vld1.32 {q12}, [r2]! 
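+        // 4-wide output tiling: q0-q3 accumulate four horizontal outputs that
+        // share the weight in q12; r11 (in_sw_step) strides between their
+        // input pixels, loaded into q4-q7 below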
+ vld1.32 {q4}, [r10] + add r10, r10, r11 + vmla.f32 q0, q4, q12 + vld1.32 {q5}, [r10] + add r10, r10, r11 + vmla.f32 q1, q5, q12 + vld1.32 {q6}, [r10] + add r10, r10, r11 + vmla.f32 q2, q6, q12 + vld1.32 {q7}, [r10] + add r10, r10, r11 + vmla.f32 q3, q7, q12 + subs r7, r7, #1 + add lr, lr, r12 + bne LoopKw4 + ldr r12, [sp, #32] // in_kh_step + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh4 + ldr r12, [sp, #44] + cmp r12, #0 + bne Relu64 + ldr r12, [sp, #40] + cmp r12, #0 + bne Relu4 + b Write4 + Relu64: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + Relu4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 + Write4: + ldr r12, [sp, #20] // block_channel + vst1.32 {q0}, [r0] + add r0, r0, r12 + vst1.32 {q1}, [r0] + add r0, r0, r12 + vst1.32 {q2}, [r0] + add r0, r0, r12 + vst1.32 {q3}, [r0] + add r0, r0, r12 + mov r12, #4 + mul r11, r11, r12 + add r1, r1, r11 // src_w += in_sw_step + sub r5, r5, #4 + cmp r5, #0 + ble LoopWEnd + cmp r5, #4 + bge LoopW + LoopW: + mov r8, r1 // src_kh, src_w + ldr r2, [sp, #-40] // weight_kh, weight + ldr r6, [sp, #8] // kernel_h + vmov q0, q13 // bias + LoopKh: + ldr r7, [sp, #12] // kernel_w + mov r10, r8 // src_kw, src_kh + LoopKw: + ldr r12, [sp, #36] //in_kw_step + vld1.32 {q1}, [r10] + add r10, r10, r12 + vld1.32 {q12}, [r2]! + vmla.f32 q0, q1, q12 + subs r7, r7, #1 + bne LoopKw + ldr r12, [sp, #32] // in_kh_step + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh + ldr r12, [sp, #44] + cmp r12, #0 + bne Relu6 + ldr r12, [sp, #40] + cmp r12, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q14 + Relu: + vmax.f32 q0, q0, q15 + Write: + ldr r12, [sp, #20] // block_channel + vst1.32 {q0}, [r0] // dst_kw += block_channel + add r0, r0, r12 + ldr r12, [sp, #28] // in_sw_step + add r1, r1, r12 // src_w += in_sw_step + subs r5, r5, #1 + bne LoopW + ldr r3, [sp, #16] // out_h_step + ldr r12, [sp, #-48] + add r12, r12, r3 + str r12, [sp, #-48] + + ldr r3, [sp, #24] // in_sh_step + ldr r12, [sp, #-44] // src_h += in_sh_step + add r12, r12, r3 + str r12, [sp, #-44] + + subs r4, r4, #1 // height + bne LoopH +LoopWEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S new file mode 100644 index 0000000000000000000000000000000000000000..e1b2ff7e013dd1f4fada011bd907943d993fe164 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S @@ -0,0 +1,125 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void ConvDwFp32Row(float* output_ptr, const float* input_ptr, const float* filter_ptr,
+//                    size_t num_pixels, size_t input_channel, size_t input_step)
+// r0: output_ptr, r1: input_ptr, r2: filter_ptr, r3: num_pixels,
+// r4: input_channel, r5: input_step
+asm_function ConvDwFp32Row
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+
+    push {r4-r6, r8, r10, r11}
+    vpush {q4-q7}
+    add sp, sp, #88
+    mov r11, r0
+    ldr r4, [sp]
+    ldr r5, [sp, #4]
+    mov r6, #4
+    mul r5, r5, r6
+    cmp r3, #0
+    ble End
+
+    LoopNumPixel:
+        mov r6, r1   // input_ptr
+        mov r8, r2   // filter_ptr
+        mov r10, r4  // input_channel
+
+        LoopDepth16In:
+            cmp r10, #16
+            blt L4
+            sub r10, r10, #16
+
+            vld1.32 {q0, q1}, [r6]!
+            vld1.32 {q4, q5}, [r8]!
+            vld1.32 {q8, q9}, [r0]!
+
+            cmp r10, #16
+            blt LoopDepth16Out
+            LoopDepth16:
+                vmla.f32 q8, q0, q4
+                vmla.f32 q9, q1, q5
+                vst1.32 {q8, q9}, [r11]!
+
+                vld1.32 {q2, q3}, [r6]!
+                vld1.32 {q6, q7}, [r8]!
+                vld1.32 {q10, q11}, [r0]!
+                vmla.f32 q10, q2, q6
+                vmla.f32 q11, q3, q7
+                vst1.32 {q10, q11}, [r11]!
+
+                vld1.32 {q0, q1}, [r6]!
+                vld1.32 {q4, q5}, [r8]!
+                vld1.32 {q8, q9}, [r0]!
+
+                sub r10, r10, #16
+                cmp r10, #16
+                bge LoopDepth16
+
+            LoopDepth16Out:
+                vmla.f32 q8, q0, q4
+                vmla.f32 q9, q1, q5
+                vst1.32 {q8, q9}, [r11]!
+
+                vld1.32 {q2, q3}, [r6]!
+                vld1.32 {q6, q7}, [r8]!
+                vld1.32 {q10, q11}, [r0]!
+                vmla.f32 q10, q2, q6
+                vmla.f32 q11, q3, q7
+                vst1.32 {q10, q11}, [r11]!
+
+        L4:
+            cmp r10, #4
+            blt L0
+
+            LoopDepth4:
+                vld1.32 {q0}, [r6]!
+                vld1.32 {q4}, [r8]!
+                vld1.32 {q8}, [r0]!
+                vmla.f32 q8, q0, q4
+                vst1.32 {q8}, [r11]!
+                sub r10, r10, #4
+                cmp r10, #4
+                bge LoopDepth4
+
+        L0:
+            cmp r10, #0
+            beq Loop16LineEnd
+
+            LoopDepth0:
+                vld1.32 d0[0], [r6]!
+                vld1.32 d2[0], [r8]!
+                vld1.32 d4[0], [r0]!
+                vmla.f32 s8, s0, s4
+                vst1.32 d4[0], [r11]!
+                subs r10, r10, #1
+                bne LoopDepth0
+
+        Loop16LineEnd:
+            subs r3, r3, #1
+            add r1, r1, r5
+            bne LoopNumPixel
+
+    End:
+        sub sp, sp, #88
+        vpop {q4-q7}
+        pop {r4-r6, r8, r10, r11}
+        bx lr
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S
new file mode 100644
index 0000000000000000000000000000000000000000..fc41d0fd8e9e46d35a95993a573f22d24343bfe4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S
@@ -0,0 +1,290 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, +// int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, +// int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, +// int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, +// int32_t *acc_max) +// #-48: dst, #-44: src, #-40: weight, #-36: bias, #0: height, #4: width, #8: kernel_h, #12: kernel_w, +// #16: out_h_step, #20: block_channel, #24: in_sh_step, #28: in_sw_step, #32: in_kh_step, #36: in_kw_step +// #40: in_zp, #44: out_zp, #48: out_multiplier, #52: left_shift, #56: right_shift, #60:acc_min, #64: acc_max +asm_function ConvDwInt8Center +// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" +// according to https://stackoverflow.com/questions/53625807 +// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway +// clang's rule seems more simple, though there are no subroutine calls here +// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + + ldr lr, [sp, #168] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + ldr lr, [sp, #204] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + ldr lr, [sp, #240] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + add sp, sp, #208 + + ldr r1, [sp, #-36] + vld1.32 {q8, q9}, [r1] + ldr r1, [sp, #44] + vld1.32 {q10, q11}, [r1] + ldr r1, [sp, #48] + vld1.32 {q12, q13}, [r1] + ldr r1, [sp, #52] + vld1.32 {q14, q15}, [r1] + + ldr r11, [sp, #28] + ldr r4, [sp] + LoopH: + ldr r1, [sp, #-44] + ldr r0, [sp, #-48] + ldr r5, [sp, #4] + LoopW2: + vmov q4, q8 + vmov q5, q9 + vmov q6, q8 + vmov q7, q9 + mov r7, r1 + ldr r3, [sp, #-40] + ldr r6, [sp, #8] + LoopKH: + mov r9, r7 + ldr r10, [sp, #12] + LoopKW: + mov r8, r9 + vld1.16 {q0}, [r3]! 
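+          // q0 holds eight int16 weights for this tap; the two-pixel main
+          // path keeps separate accumulator pairs (q4/q5 and q6/q7), so each
+          // pass of LoopW2 produces two output pixels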
+ ldr lr, [sp, #40] + vld1.8 {d2}, [lr] + + vld1.8 {d3}, [r8] + add r8, r8, r11 + vsubl.s8 q2, d3, d2 + vmlal.s16 q4, d4, d0 + vmlal.s16 q5, d5, d1 + + vld1.8 {d3}, [r8] + add r8, r8, r11 + vsubl.s8 q2, d3, d2 + vmlal.s16 q6, d4, d0 + vmlal.s16 q7, d5, d1 + + ldr r12, [sp, #36] + add r9, r9, r12 + subs r10, r10, #1 + bne LoopKW + ldr r12, [sp, #32] + add r7, r7, r12 + subs r6, r6, #1 + bne LoopKH + + vshl.s32 q4, q4, q14 + vshl.s32 q5, q5, q15 + vshl.s32 q6, q6, q14 + vshl.s32 q7, q7, q15 + + vqrdmulh.s32 q4, q4, q12 + vqrdmulh.s32 q5, q5, q13 + vqrdmulh.s32 q6, q6, q12 + vqrdmulh.s32 q7, q7, q13 + + sub lr, sp, #144 + vld1.32 {q0, q1}, [lr] + + vand q2, q4, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q4, q4, q2 + vrshl.s32 q4, q4, q0 + + vand q2, q5, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q5, q5, q2 + vrshl.s32 q5, q5, q1 + + vand q2, q6, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q6, q6, q2 + vrshl.s32 q6, q6, q0 + + vand q2, q7, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q7, q7, q2 + vrshl.s32 q7, q7, q1 + + vadd.i32 q4, q4, q10 + vadd.i32 q5, q5, q11 + vadd.i32 q6, q6, q10 + vadd.i32 q7, q7, q11 + + sub lr, sp, #176 + vld1.32 {q0, q1}, [lr] + vmax.s32 q4, q4, q0 + vmax.s32 q5, q5, q1 + vmax.s32 q6, q6, q0 + vmax.s32 q7, q7, q1 + + sub lr, sp, #208 + vld1.32 {q0, q1}, [lr] + vmin.s32 q4, q4, q0 + vmin.s32 q5, q5, q1 + vmin.s32 q6, q6, q0 + vmin.s32 q7, q7, q1 + + vqmovn.s32 d0, q4 + vqmovn.s32 d1, q5 + vqmovn.s32 d2, q6 + vqmovn.s32 d3, q7 + vqmovn.s16 d0, q0 + vqmovn.s16 d1, q1 + + + ldr r12, [sp, #20] + mov r8, r0 + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + vst1.8 {d0[4]}, [r8]! + vst1.8 {d0[5]}, [r8]! + vst1.8 {d0[6]}, [r8]! + vst1.8 {d0[7]}, [r8]! + add r0, r0, r12 + + mov r8, r0 + vst1.8 {d1[0]}, [r8]! + vst1.8 {d1[1]}, [r8]! + vst1.8 {d1[2]}, [r8]! + vst1.8 {d1[3]}, [r8]! + vst1.8 {d1[4]}, [r8]! + vst1.8 {d1[5]}, [r8]! + vst1.8 {d1[6]}, [r8]! + vst1.8 {d1[7]}, [r8]! + add r0, r0, r12 + + add r1, r1, r11 + add r1, r1, r11 + subs r5, r5, #2 + beq LoopEndW + cmp r5, #2 + bge LoopW2 + + LoopW: + vmov q4, q8 + vmov q5, q9 + mov r7, r1 + ldr r3, [sp, #-40] + ldr r6, [sp, #8] + LoopKH1: + mov r9, r7 + ldr r10, [sp, #12] + LoopKW1: + vld1.16 {q0}, [r3]! + ldr lr, [sp, #40] + vld1.8 {d2}, [lr] + + vld1.8 {d3}, [r9] + vsubl.s8 q2, d3, d2 + vmlal.s16 q4, d4, d0 + vmlal.s16 q5, d5, d1 + + ldr r12, [sp, #36] + add r9, r9, r12 + subs r10, r10, #1 + bne LoopKW1 + ldr r12, [sp, #32] + add r7, r7, r12 + subs r6, r6, #1 + bne LoopKH1 + + vshl.s32 q4, q4, q14 + vshl.s32 q5, q5, q15 + + vqrdmulh.s32 q4, q4, q12 + vqrdmulh.s32 q5, q5, q13 + + sub lr, sp, #144 + vld1.32 {q0, q1}, [lr] + vand q2, q4, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q4, q4, q2 + vrshl.s32 q4, q4, q0 + + vand q2, q5, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q5, q5, q2 + vrshl.s32 q5, q5, q1 + + vadd.i32 q4, q4, q10 + vadd.i32 q5, q5, q11 + + sub lr, sp, #176 + vld1.32 {q0, q1}, [lr] + vmax.s32 q4, q4, q0 + vmax.s32 q5, q5, q1 + + sub lr, sp, #208 + vld1.32 {q0, q1}, [lr] + vmin.s32 q4, q4, q0 + vmin.s32 q5, q5, q1 + + vqmovn.s32 d0, q4 + vqmovn.s32 d1, q5 + vqmovn.s16 d0, q0 + + mov r8, r0 + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + vst1.8 {d0[4]}, [r8]! + vst1.8 {d0[5]}, [r8]! + vst1.8 {d0[6]}, [r8]! + vst1.8 {d0[7]}, [r8]! 
+ ldr r12, [sp, #20] + add r0, r0, r12 + add r1, r1, r11 + subs r5, r5, #1 + bne LoopW + + LoopEndW: + ldr r12, [sp, #16] + ldr r1, [sp, #-48] + add r1, r1, r12 + str r1, [sp, #-48] + ldr r12, [sp, #24] + ldr r1, [sp, #-44] + add r1, r1, r12 + str r1, [sp, #-44] + subs r4, r4, #1 + bne LoopH + + LoopEndH: + sub sp, sp, #208 + vpop {q0, q1} + vpop {q0, q1} + vpop {q0, q1} + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S new file mode 100644 index 0000000000000000000000000000000000000000..d74d5e2cbb25186704a4dc67a6d0391174e45b87 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, +// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +asm_function ConvDwInt8PostAlign4 + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + vdup.32 q14, r4 + + ldr r5, [sp, #4] // left_shift + vdup.32 q13, r5 + + ldr r6, [sp, #8] // right_shift + vdup.32 q12, r6 + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vshl.s32 q1, q1, q13 + vqrdmulh.s32 q1, q1, q14 + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! + + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! 
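+    // the sequence below is the usual fixed-point requantization; as a rough
+    // C sketch (helper names follow gemmlowp conventions and are illustrative,
+    // not from this repo):
+    //   v = SaturatingRoundingDoublingHighMul(v << left_shift, out_multiplier);  // vshl + vqrdmulh
+    //   v = RoundingDivideByPOT(v, right_shift);   // vand/vshr/vqadd fixup, then vrshl
+    //                                              // (vrshl with a negative count is a rounding right shift)
+    //   v = clamp(v + output_zp, acc_min, acc_max);  // vadd + vmax + vmin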
+ + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 0000000000000000000000000000000000000000..40bbb8a40c6490fc231a709bdaceeb12fce9567d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +asm_function ConvDwInt8PostAlign4PerChannel + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + ldr r5, [sp, #4] // left_shift + ldr r6, [sp, #8] // right_shift + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q1, q1, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q1, q1, q14 + vld1.32 {q12}, [r6]! + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! 
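+    // per-channel variant of the pipeline above: q13/q14/q12 were reloaded
+    // four values at a time from the left_shift (r5), out_multiplier (r4) and
+    // right_shift (r6) arrays instead of being broadcast once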
+ + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S new file mode 100644 index 0000000000000000000000000000000000000000..2833b4a2786bf656aa3dce4a20a60b6c8b8009dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, +// int output_channel, int input_step, int8_t input_zp) +// r0: output_ptr, r1: input_ptr, r2: weight_ptr, r3: num_pixels, +// r4: output_channel, r5: input_step, r6: input_zp, + +asm_function ConvDwInt8Row + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r9-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + cmp r3, #0 + beq End + + ldr r4, [sp] // channel + ldr r5, [sp, #4] // input_step + ldr r6, [sp, #8] // input_zp + vdup.8 d30, r6 + + mov r7, r0 + + LoopPixel: + mov r8, r1 // input + mov r10, r2 // weight + mov r11, r4 + + LoopDepth16In: + cmp r11, #16 + blt L8 + sub r11, r11, #16 + + vld1.8 {q0}, [r8]! + vld1.16 {q1, q2}, [r10]! // weight + + vsubl.s8 q3, d0, d30 // -zp + vld1.32 {q4, q5}, [r0]! + vmlal.s16 q4, d6, d2 + vmlal.s16 q5, d7, d3 + + cmp r11, #16 + blt LoopDepth16Out + LoopDepth16: + vst1.32 {q4, q5}, [r7]! + + vsubl.s8 q6, d1, d30 + vld1.32 {q7, q8}, [r0]! + vmlal.s16 q7, d12, d4 + vmlal.s16 q8, d13, d5 + vst1.32 {q7, q8}, [r7]! + + vld1.8 {q0}, [r8]! + vld1.16 {q1, q2}, [r10]! // weight + + vsubl.s8 q3, d0, d30 // -zp + vld1.32 {q4, q5}, [r0]! + vmlal.s16 q4, d6, d2 + vmlal.s16 q5, d7, d3 + + sub r11, r11, #16 + cmp r11, #16 + bge LoopDepth16 + + LoopDepth16Out: + vst1.32 {q4, q5}, [r7]! + + vsubl.s8 q6, d1, d30 + vld1.32 {q7, q8}, [r0]! 
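+      // finish the 16-channel block: q6 already holds (src - input_zp)
+      // widened to int16; accumulate it into the int32 row buffer (q7/q8)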
+ vmlal.s16 q7, d12, d4 + vmlal.s16 q8, d13, d5 + vst1.32 {q7, q8}, [r7]! + + L8: + cmp r11, #8 + blt L0 + + LoopDepth8: + vld1.8 {d0}, [r8]! + vld1.16 {d2, d3}, [r10]! // weight + + vsubl.s8 q2, d0, d30 // -zp + + vld1.32 {q3}, [r0]! + vmlal.s16 q3, d4, d2 + vst1.32 {q3}, [r7]! + + vld1.32 {q4}, [r0]! + vmlal.s16 q4, d5, d3 + vst1.32 {q4}, [r7]! + + sub r11, r11, #8 + cmp r11, #8 + bge LoopDepth8 + + L0: + cmp r11, #0 + beq LoopDepthEnd + + LoopDepth0: + ldrsb r12, [r8], #1 + ldrsh r9, [r10], #2 + sub r12, r12, r6 + + ldr lr, [r0], #4 + smlabb r12, r12, r9, lr + str r12, [r7], #4 + + subs r11, r11, #1 + bne L0 + + LoopDepthEnd: + add r1, r1, r5 + subs r3, r3, #1 + bne LoopPixel + + End: + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r8, r9-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..f42a1b82c6766c8ea40bf615e95a1e9e6a470315 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +asm_function DeconvDwFp32Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.32 {q1}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.32 {q2}, [r2]! 
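+        // deconv scatters rather than gathers: each tap is a read-modify-write
+        // on the output, dst (q0) += src pixel (q1) * weight (q2)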
+ vmla.f32 q0, q1, q2 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..e7e6cd15fb168cdea06450d6c855434ab732f7fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +asm_function DeconvDwInt8Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.16 {d2}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.16 {d24}, [r2]! 
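+        // int16 deconv accumulation: int32 dst (q0) += src (d2) * weight (d24)
+        // via vmlal.s16, four channels at a time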
+ vmlal.s16 q0, d2, d24 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S new file mode 100644 index 0000000000000000000000000000000000000000..afc0cd3055fd8ada62fac93033a638fa13033375 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, +// int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, +// int32_t acc_max) +// r0: dst, r1: output_buffer, r2: bias, r3: block_channel, r4: pixel_nums, r5: out_multiplier, +// r6: left_shift, r7: right_shift, r8: out_zp, r9: acc_min, r10: acc_max + +asm_function DeconvDwInt8Post + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8} + add sp, sp, #20 + + vld1.32 {q9}, [r2] + ldr r4, [sp] + ldr r5, [sp, #4] + vdup.32 q14, r5 // out_multiplier + ldr r6, [sp, #8] + vdup.32 q13, r6 // left_shift + ldr r5, [sp, #12] + vdup.32 q12, r5 // right_shift + ldr r6, [sp, #16] + vdup.32 q15, r6 // output_zp + ldr r7, [sp, #20] + vdup.32 q11, r7 // acc_min + ldr r8, [sp, #24] + vdup.32 q10, r8 // acc_max + + LoopCount: + mov r8, r0 + vld1.32 {q0}, [r1]! + vand q0, q0, q9 + + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! 
+    add r0, r0, r3
+
+    sub r4, r4, #1
+    cmp r4, #1
+    bge LoopCount
+    End:
+    sub sp, sp, #20
+    pop {r4-r8}
+    bx lr
+
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S
new file mode 100644
index 0000000000000000000000000000000000000000..a0464c90984775cd56054f7513f283a5d11a7c1a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S
@@ -0,0 +1,249 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset);
+// r0: output, r1: input, r2: weight, r3: ksize, r4: ic8, r5: oc4, r6: offset
+asm_function IndirectGemmInt16to32_8x4
+
+    .macro INIT_ZERO
+        // we could also use "vmov.s32 q12, #0" to initialize q12 with 0
+        veor q12, q12, q12
+        veor q13, q13, q13
+        veor q14, q14, q14
+        veor q15, q15, q15
+    .endm
+
+    // at return, clang generates "push {lr}" / "pop {pc}" while gcc generates "bx lr"
+    // according to https://stackoverflow.com/questions/53625807
+    // even if we jumped to the link register instead of saving it, we would still have to save it for subroutine calls anyway
+    // clang's rule seems simpler, though there are no subroutine calls here
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+    push {r4-r8, r10, lr}
+
+    ldr r4, [sp, #28]
+    ldr r5, [sp, #32]
+    ldr r6, [sp, #36]
+
+    vpush {q4-q7}
+
+    LoopOc:
+
+        mov r7, r3
+        mov r8, r1
+
+        LoopKsize:
+            mov r10, r0
+            INIT_ZERO
+
+            // load input
+            vld1.16 {q0, q1}, [r8]!
+            // load weight
+            vld1.16 {q4}, [r2]!
+            vmull.s16 q8, d8, d0[0]
+            vmull.s16 q9, d8, d2[0]
+            // load weight
+            vld1.16 {q5}, [r2]!
+            vmlal.s16 q8, d9, d0[1]
+            vmlal.s16 q9, d9, d2[1]
+            // load input
+            vld1.16 {q2, q3}, [r8]!
+            vmlal.s16 q8, d10, d0[2]
+            vmlal.s16 q9, d10, d2[2]
+            vmlal.s16 q8, d11, d0[3]
+            vmlal.s16 q9, d11, d2[3]
+            // load weight
+            vld1.16 {q6, q7}, [r2]!
+            vmull.s16 q10, d8, d4[0]
+            vmull.s16 q11, d8, d6[0]
+
+            subs r12, r4, #1
+            beq LoopIcEnd
+
+            LoopIc:
+
+                vmlal.s16 q10, d9, d4[1]
+                vmlal.s16 q11, d9, d6[1]
+                vmlal.s16 q10, d10, d4[2]
+                vmlal.s16 q11, d10, d6[2]
+                vmlal.s16 q10, d11, d4[3]
+                vmlal.s16 q11, d11, d6[3]
+
+                vmlal.s16 q8, d12, d1[0]
+                vmlal.s16 q9, d12, d3[0]
+                vmlal.s16 q8, d13, d1[1]
+                vmlal.s16 q9, d13, d3[1]
+                vmlal.s16 q8, d14, d1[2]
+                vmlal.s16 q9, d14, d3[2]
+                vmlal.s16 q8, d15, d1[3]
+                vmlal.s16 q9, d15, d3[3]
+                // load input
+                vld1.16 {q0, q1}, [r8]!
+                vmlal.s16 q10, d12, d5[0]
+                vmlal.s16 q11, d12, d7[0]
+                vmlal.s16 q10, d13, d5[1]
+                vmlal.s16 q11, d13, d7[1]
+                vmlal.s16 q10, d14, d5[2]
+                vmlal.s16 q11, d14, d7[2]
+                vmlal.s16 q10, d15, d5[3]
+                vmlal.s16 q11, d15, d7[3]
+
+                // load input
+                vld1.16 {q2, q3}, [r8]!
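+                // 8x4 int16 GEMM tile: q8-q15 accumulate int32 partial sums
+                // (eight input pixels x four output channels); each vmlal.s16
+                // multiplies a weight vector (d8-d15) by one broadcast input
+                // lane of d0-d7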
+ vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + // load weight + vld1.16 {q4, q5}, [r2]! + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d8, d0[0] + vmlal.s16 q9, d8, d2[0] + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load weight + vld1.16 {q6, q7}, [r2]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + vmlal.s16 q10, d8, d4[0] + vmlal.s16 q11, d8, d6[0] + + subs r12, r12, #1 + bne LoopIc + + LoopIcEnd: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vst1.32 {q8}, [r10], r6 + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vst1.32 {q9}, [r10], r6 + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.s16 {q2, q3}, [r8]! 
+ vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vst1.32 {q10}, [r10], r6 + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vst1.32 {q11}, [r10], r6 + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + vst1.32 {q12}, [r10], r6 + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vst1.32 {q13}, [r10], r6 + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + + vst1.32 {q14}, [r10], r6 + vst1.32 {q15}, [r10] + + subs r7, r7, #1 + add r0, r0, #16 + bne LoopKsize + + subs r5, r5, #1 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, r10, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S new file mode 100644 index 0000000000000000000000000000000000000000..dd459e1fe3c8a246c143f684c1e6329021ef4bbb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S @@ -0,0 +1,306 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
+//                           size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
+//                           int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric,
+//                           size_t per_channel, size_t per_channel_offset);
+// r0: output, r1: input, r2: weight, r3: bias, r4: ksize, r5: ic4, r6: oc, r7: offset
+// the remaining arguments (input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after,
+// asymmetric, per_channel, per_channel_offset) are passed on the stack and loaded into r10/r11/lr as needed
+asm_function IndirectGemmInt8_2x4
+
+    .macro INIT_BIAS
+        veor q10, q10, q10
+        veor q11, q11, q11
+        veor q12, q12, q12
+        veor q13, q13, q13
+        veor q14, q14, q14
+        veor q15, q15, q15
+    .endm
+
+    // at return, clang generates "push {lr}" / "pop {pc}" while gcc generates "bx lr"
+    // according to https://stackoverflow.com/questions/53625807
+    // even if we jumped to the link register instead of saving it, we would still have to save it for subroutine calls anyway
+    // clang's rule seems simpler, though there are no subroutine calls here
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+    push {r4-r8, r10, r11, lr}
+    vpush {q4-q7}
+    add sp, sp, #96
+
+    ldr r4, [sp]
+    ldr r5, [sp, #4]
+    ldr r6, [sp, #8]
+    ldr r7, [sp, #12]
+
+    mul r5, r4, r5
+    mov r4, #1
+
+    LoopOc:
+
+        mov r8, r4
+        mov r12, r1
+
+        LoopKsize:
+            INIT_BIAS
+            mov r11, r0
+
+            // as some processors do not support the sdot intrinsic, we use the raw instruction word
+            // dot-product support is still judged dynamically; the instruction word is only there to ensure compilation
+            // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf
+            // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is
+            // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd
+            // mmmmm/nnnnn/ddddd are the NEON register numbers, H/L are the high/low bits of the index
+
+            // load input for output 1-2
+            vld1.8 {q0, q1}, [r12]!
+            // load weight for oc 1-2
+            vld1.8 {q2, q3}, [r2]!
+            vmull.s8 q6, d0, d4
+            vmull.s8 q7, d0, d6
+            vmlal.s8 q6, d1, d5
+            vmlal.s8 q7, d1, d7
+            vpaddl.s16 q8, q6
+            vpaddl.s16 q9, q7
+            // load weight for oc 3-4
+            vld1.8 {q4, q5}, [r2]!
+            vmull.s8 q6, d0, d8
+            vmull.s8 q7, d0, d10
+            vmlal.s8 q6, d1, d9
+            vmlal.s8 q7, d1, d11
+
+            subs r10, r5, #1
+            beq LoopIcEnd
+
+            LoopIc:
+                // load input for output 1
+                vld1.8 {q0}, [r12]!
+                vpadal.s16 q10, q6
+                vpadal.s16 q11, q7
+                vmull.s8 q6, d2, d4
+                vmull.s8 q7, d2, d6
+                vmlal.s8 q6, d3, d5
+                vmlal.s8 q7, d3, d7
+                vld1.8 {q2, q3}, [r2]!
+                vpadal.s16 q12, q6
+                vpadal.s16 q13, q7
+                vmull.s8 q6, d2, d8
+                vmull.s8 q7, d2, d10
+                vmlal.s8 q6, d3, d9
+                vmlal.s8 q7, d3, d11
+                vld1.8 {q4, q5}, [r2]!
+                vpadal.s16 q14, q6
+                vpadal.s16 q15, q7
+                vmull.s8 q6, d0, d4
+                vmull.s8 q7, d0, d6
+                vmlal.s8 q6, d1, d5
+                vmlal.s8 q7, d1, d7
+                vld1.8 {q1}, [r12]!
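+                // int8 path without sdot: vmull.s8/vmlal.s8 form int16
+                // products and vpadal.s16 pairwise-adds them into the int32
+                // accumulators q8-q15, so each int32 lane gathers four int8
+                // products per step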
+ vpadal.s16 q8, q6 + vpadal.s16 q9, q7 + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs r10, r10, #1 + bne LoopIc + + LoopIcEnd: + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + + // pairwise add + vpadd.i32 d16, d16, d17 + vpadd.i32 d18, d18, d19 + vpadd.i32 d20, d20, d21 + vpadd.i32 d22, d22, d23 + vpadd.i32 d24, d24, d25 + vpadd.i32 d26, d26, d27 + vpadd.i32 d28, d28, d29 + vpadd.i32 d30, d30, d31 + + vpadd.i32 d16, d16, d18 + vpadd.i32 d17, d20, d22 + vpadd.i32 d24, d24, d26 + vpadd.i32 d25, d28, d30 + + // load sum + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSum + ldr r10, [sp, #16] + ldr lr, [sp, #48] + cmp lr, #0 + beq SymSum + ldr lr, [sp, #52] + vld1.32 {d0, d1}, [r10] + add r10, r10, lr + vld1.32 {d2, d3}, [r10] + b AddSum + SymSum: + vld1.32 {d0[], d1[]}, [r10]! + vld1.32 {d2[], d3[]}, [r10]! + AddSum: + vsub.i32 q8, q8, q0 + vsub.i32 q12, q12, q1 + NoSum: + cmp r3, #0 + beq NoBias + vld1.32 {d4, d5}, [r3] + vadd.i32 q8, q8, q2 + vadd.i32 q12, q12, q2 + + NoBias: + ldr lr, [sp, #48] + cmp lr, #0 + bne PerChannel + ldr lr, [sp, #36] + vld1.32 {d6[], d7[]}, [lr] + ldr lr, [sp, #32] + vld1.32 {d8[], d9[]}, [lr] + ldr lr, [sp, #40] + vld1.32 {d10[], d11[]}, [lr] + b QuantizeStart + PerChannel: + ldr lr, [sp, #36] + vld1.32 {d6, d7}, [lr] + ldr lr, [sp, #32] + vld1.32 {d8, d9}, [lr] + ldr lr, [sp, #40] + vld1.32 {d10, d11}, [lr] + QuantizeStart: + vshl.s32 q8, q8, q3 + vshl.s32 q12, q12, q3 + + vqrdmulh.s32 q8, q8, q4 + vqrdmulh.s32 q12, q12, q4 + + vand q3, q5, q8 + vshr.s32 q3, q3, #31 + vqadd.s32 q8, q8, q3 + vrshl.s32 q8, q8, q5 + vand q4, q5, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q12, q12, q4 + vrshl.s32 q12, q12, q5 + + ldr r10, [sp, #28] + vdup.32 q6, r10 + vadd.i32 q8, q8, q6 + vadd.i32 q12, q12, q6 + + ldr r10, [sp, #20] + vdup.32 q0, r10 + vmax.s32 q8, q8, q0 + vmax.s32 q12, q12, q0 + + ldr r10, [sp, #24] + vdup.32 q1, r10 + vmin.s32 q8, q8, q1 + vmin.s32 q12, q12, q1 + + vqmovn.s32 d30, q8 + vqmovn.s32 d31, q12 + vqmovn.s16 d0, q15 + + // prefetching is not preferred while writing results in spite of cache missing + // you could try prfm pstl2strm + WriteStart: + cmp r6, #1 + beq Write1 + cmp r6, #2 + beq Write2 + cmp r6, #3 + beq Write3 + b Write4 + Write1: + vst1.8 {d0[0]}, [r11], r7 + vst1.8 {d0[1]}, [r11] + add r0, r0, #1 + b WriteEnd + Write2: + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + add r0, r0, #2 + b WriteEnd + Write3: + add r14, r11, #2 + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + vst1.8 {d0[0]}, [r14], r7 + vst1.8 {d0[1]}, [r14] + add r0, r0, #3 + b WriteEnd + Write4: + vst1.32 {d0[0]}, [r11], r7 + vst1.32 {d0[1]}, [r11] + add r0, r0, #4 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + cmp r6, #4 + ble LoopOcEnd + ldr lr, [sp, #48] + cmp lr, #0 + beq NoChannelForward + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSumForward + ldr lr, [sp, #16] + add lr, lr, #16 + str lr, [sp, #16] + NoSumForward: + ldr lr, [sp, #36] + add lr, lr, #16 + str lr, [sp, #36] + ldr lr, [sp, #32] + add lr, lr, #16 + str lr, [sp, #32] + ldr lr, [sp, #40] + add lr, lr, #16 + str lr, [sp, #40] + NoChannelForward: + sub r6, r6, #4 + cmp r3, #0 + beq NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + b LoopOc + +LoopOcEnd: + sub sp, sp, #96 + vpop {q4-q7} + pop 
{r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..f45eb0d7f71544456ee4056a9f785f664b08ddeb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S @@ -0,0 +1,195 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: col + +asm_function MatVecMulFp32 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r9, r10, r11, lr} + add sp, sp, #52 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + mov r10, #4 + mul r10, r10, r5 // stride = depth * sizeof(float) + mov r11, #4 + mul r11, r11, r10 // stride x 4 + + cmp r6, #4 + blt Col1Loop + +Col4Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q15, q15, q15 + + cmp r8, #4 + blt Col4Depth1 + + Col4Depth4: + vld1.f32 {q8}, [r7]! + add lr, r9, r10 + vld1.f32 {q0}, [r9]! + vld1.f32 {q1}, [lr], r10 + vld1.f32 {q2}, [lr], r10 + vld1.f32 {q3}, [lr] + + vmla.f32 q9, q8, q0 + vmla.f32 q10, q8, q1 + vmla.f32 q11, q8, q2 + vmla.f32 q12, q8, q3 + sub r8, r8, #4 + cmp r8, #4 + bge Col4Depth4 + + vpadd.f32 d26, d18, d20 + vpadd.f32 d27, d19, d21 + vpadd.f32 d28, d22, d24 + vpadd.f32 d29, d23, d25 + vadd.f32 d30, d26, d27 + vadd.f32 d31, d28, d29 + cmp r8, #0 + beq Col4End + + Col4Depth1: + vld1.f32 {d0[0]}, [r7]! + add lr, r9, r10 + vld1.f32 {d2[0]}, [r9]! + vld1.f32 {d2[1]}, [lr], r10 + vld1.f32 {d3[0]}, [lr], r10 + vld1.f32 {d3[1]}, [lr] + + vmla.f32 q15, q1, d0[0] + subs r8, r8, #1 + bne Col4Depth1 + + Col4End: + cmp r3, #0 + beq Col4Activation + vld1.f32 {q13}, [r3]! + vadd.f32 q15, q15, q13 + + Col4Activation: + cmp r4, #3 + beq Col4Relu6 + cmp r4, #1 + beq Col4Relu + b Col4Write + + Col4Relu6: + vmov.i32 q12, #6 + vcvt.f32.s32 q12, q12 + vmin.f32 q15, q15, q12 + + Col4Relu: + veor q13, q13, q13 + vmax.f32 q15, q15, q13 + + Col4Write: + vst1.f32 {q15}, [r2]! + subs r6, r6, #4 + beq End + add r1, r1, r11 + cmp r6, #4 + bge Col4Loop + +Col1Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + veor q10, q10, q10 + veor q13, q13, q13 + veor q15, q15, q15 + + cmp r8, #4 + blt Col1Depth1 + + Col1Depth4: + vld1.f32 {q0}, [r7]! + vld1.f32 {q1}, [r9]! 
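+    // single-column tail: q10 accumulates elementwise a*b over depth; after
+    // the loop the lanes are reduced with vpadd/vadd and only d30[0] is stored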
+ + vmla.f32 q10, q1, q0 + sub r8, r8, #4 + cmp r8, #4 + bge Col1Depth4 + + vpadd.f32 d24, d20, d22 + vpadd.f32 d25, d21, d23 + vadd.f32 d30, d24, d25 + cmp r8, #0 + beq Col1End + + Col1Depth1: + vld1.f32 {d0[0]}, [r7]! + vld1.f32 {d2[0]}, [r9]! + + vmla.f32 d30, d2, d0[0] + subs r8, r8, #1 + bne Col1Depth1 + + Col1End: + cmp r3, #0 + beq Col1Activation + vld1.f32 {d28[0]}, [r3]! + vadd.f32 d30, d30, d28 + + Col1Activation: + cmp r4, #3 + beq Col1Relu6 + cmp r4, #1 + beq Col1Relu + b Col1Write + + Col1Relu6: + vmov.i32 d26, #6 + vcvt.f32.s32 d26, d26 + vmin.f32 d30, d30, d26 + + Col1Relu: + veor d24, d24, d24 + vmax.f32 d30, d30, d24 + + Col1Write: + vst1.f32 {d30[0]}, [r2]! + subs r6, r6, #1 + beq End + add r1, r1, r10 + b Col1Loop + +End: + sub sp, sp, #52 + pop {r0-r8, r9, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..f36fe067acd6f5c4b65c443fe2f65e430ef709bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S @@ -0,0 +1,381 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +asm_function MatmulFloatNeon32 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #32 // sizeof(float) * 8 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 8 * depth + ldr lr, [sp, #24] + cmp lr, #0 + beq NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopCol: + ldr r6, [sp, #8] // reload lhs row + ldr r0, [sp, #-48] // reload lhs ptr + ldr r2, [sp, #-40] // reload dst ptr + + LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r5, [sp, #4] // reload depth + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! 
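+      // 4x8 micro-kernel: per depth step, broadcast each of the four lhs
+      // values (lanes of d0/d1) against eight rhs floats (q1, q2), updating
+      // the 4x8 accumulator block q8-q15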
+ vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vld1.32 {q1}, [r3] + sub r3, r3, #16 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #24] + cmp lr, #0 + bne WriteWino + ldr lr, [sp, #20] + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + b WriteEnd + Write2: + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + b WriteEnd + Write3: + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + b WriteEnd + Write4: + vst1.32 {d16, d17}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + add r2, r2, r8 + b WriteEnd + Write5: + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + b WriteEnd + Write6: + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + b WriteEnd + Write7: + add lr, r2, #24 + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + 
vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + b WriteEnd + WriteC8: + vst1.32 {q8, q9}, [r2]! + vst1.32 {q10, q11}, [r2]! + vst1.32 {q12, q13}, [r2]! + vst1.32 {q14, q15}, [r2]! + str r2, [sp, #-40] + b WriteEnd + WriteWino: + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + add r2, r2, r11 + b WriteEnd + Write8: + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + + WriteEnd: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + + LoopRowEnd: + ldr r1, [sp, #-44] + add r1, r1, r12 // rhs ptr + stride + str r1, [sp, #-44] + cmp r3, #0 + beq NoBiasStep + add r3, r3, #32 // bias ptr + stride + NoBiasStep: + ldr lr, [sp, #24] + cmp lr, #0 + bne WinoDstStep + ldr lr, [sp, #20] + cmp lr, #0 + beq NoDstStep + ldr r2, [sp, #-40] + add r2, r2, #32 // dst ptr + stride + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + ldr r2, [sp, #-40] + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + +LoopColEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..83d6113f4006ff1da3c8a573469104d485e5fbce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S @@ -0,0 +1,422 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +asm_function MatmulFloatNeon32Opt + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #16 // sizeof(float) * 4 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 4 * depth + ldr lr, [sp, #20] + cmp lr, #0 + bne NoC8Steps + mov lr, #32 + mul r10, r6, lr +NoC8Steps: + cmp lr, #2 + bne NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol: + ldr lr, [sp, #20] + cmp lr, #0 + beq NoReloadDst + ldr r2, [sp, #-40] // reload dst ptr + NoReloadDst: + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmul.f32 q8, q1, d0[0] + vmul.f32 q9, q2, d0[0] + vmul.f32 q10, q1, d0[1] + vmul.f32 q11, q2, d0[1] + vmul.f32 q12, q1, d1[0] + vmul.f32 q13, q2, d1[0] + vmul.f32 q14, q1, d1[1] + vmul.f32 q15, q2, d1[1] + + subs r5, r5, #1 + beq Bias + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vld1.32 {q1}, [r3]! 
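+    // The 8-wide bias (q0 low half, q1 high half) is added to each of the
+    // four accumulated output rows below.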
+ vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #20] + cmp lr, #2 + beq WriteWino + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + add lr, r2, #4 + str lr, [sp, #-40] + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + add r2, r2, #4 + b WriteEnd + Write2: + add lr, r2, #8 + str lr, [sp, #-40] + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + add r2, r2, #8 + b WriteEnd + Write3: + add lr, r2, #12 + str lr, [sp, #-40] + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + add r2, r2, #12 + b WriteEnd + Write4: + add lr, r2, #16 + str lr, [sp, #-40] + vst1.32 {d16, d17}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + add r2, r2, r8 + add r2, r2, #16 + b WriteEnd + Write5: + add lr, r2, #20 + str lr, [sp, #-40] + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + add r2, r2, #20 + b WriteEnd + Write6: + add lr, r2, #24 + str lr, [sp, #-40] + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + add r2, r2, #24 + b WriteEnd + Write7: + add lr, r2, #28 + str lr, [sp, #-40] + add lr, r2, #24 + add r4, r2, #16 + vst1.32 {d16, 
d17}, [r2] + vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + add r2, r2, #28 + b WriteEnd + WriteC8: + mov lr, r2 + vst1.32 {q8, q9}, [lr]! + vst1.32 {q10, q11}, [lr]! + vst1.32 {q12, q13}, [lr]! + vst1.32 {q14, q15}, [lr]! + add r2, r2, r10 + b WriteEnd + WriteWino: + add lr, r2, r10 + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + str lr, [sp, #-40] + b WriteEnd + Write8: + add lr, r2, #32 + str lr, [sp, #-40] + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + add r2, r2, #32 + + WriteEnd: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + + LoopColEnd: + ldr r0, [sp, #-48] + add r0, r0, r12 // lhs ptr + stride (next 4-row block) + str r0, [sp, #-48] + ldr lr, [sp, #20] + cmp lr, #0 + beq C8DstStep + cmp lr, #2 + beq WinoDstStep + mov lr, #4 + ldr r7, [sp, #12] // reload rhs col + mul lr, lr, r7 + sub r2, r2, lr + str r2, [sp, #-40] + b NoDstStep + C8DstStep: + ldr lr, [sp, #-40] + add r2, lr, #128 + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + +LoopRowEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S new file mode 100644 index 0000000000000000000000000000000000000000..253700999695be9a74913aa62f9f169d133179a3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S @@ -0,0 +1,578 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 + +asm_function MatmulFloatNeon32Opt12x4 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #48 // sizeof(float) * 12 + mul r12, r5, lr // block stride of lhs: sizeof(float) * 12 * depth + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopRowStart: + cmp r6, #4 + ble LoopRow4 + cmp r6, #8 + ble LoopRow8 + +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + vmul.f32 q12, q3, d4[0] + vmul.f32 q13, q3, d4[1] + vmul.f32 q14, q3, d5[0] + vmul.f32 q15, q3, d5[1] + + subs r5, r5, #1 + beq Bias + + LoopDepth: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + vmla.f32 q12, q3, d4[0] + vmla.f32 q13, q3, d4[1] + vmla.f32 q14, q3, d5[0] + vmla.f32 q15, q3, d5[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q0 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q0 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q0 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q0 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + b Write + +LoopRow8: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol_R8: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! 
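+    // lhs is packed 12 rows wide, so this 8-row path still loads rows 8-11
+    // into q2 (and ignores them) to keep the read pointer stepping by 12.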
+ vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + subs r5, r5, #1 + beq Bias_R8 + + LoopDepth_R8: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + subs r5, r5, #1 + bne LoopDepth_R8 + + Bias_R8: + cmp r3, #0 + beq Activation_R8 + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q0 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q0 + + Activation_R8: + ldr lr, [sp] + cmp lr, #3 + beq Relu6_R8 + cmp lr, #1 + beq Relu_R8 + b Write + + Relu6_R8: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + + Relu_R8: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + b Write + +LoopRow4: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol_R4: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + subs r5, r5, #1 + beq Bias_R4 + + LoopDepth_R4: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + subs r5, r5, #1 + bne LoopDepth_R4 + + Bias_R4: + cmp r3, #0 + beq Activation_R4 + vld1.32 {q0}, [r3]! 
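+    // The 4-wide column bias (q0) is added to the four remaining row
+    // accumulators q4-q7 below.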
+ vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + + Activation_R4: + ldr lr, [sp] + cmp lr, #3 + beq Relu6_R4 + cmp lr, #1 + beq Relu_R4 + b Write + + Relu6_R4: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + + Relu_R4: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + + Write: + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + b Write4 + + Write1: + add lr, r2, #4 + str lr, [sp, #-40] + vst1.32 d8[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d10[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d12[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d14[0], [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 d16[0], [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 d18[0], [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 d22[0], [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 d26[0], [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 d30[0], [r2] + add r2, r2, r8 + add r2, r2, #4 + b WriteEnd + Write2: + add lr, r2, #8 + str lr, [sp, #-40] + vst1.32 d8, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d10, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d12, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d14, [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 d16, [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 d18, [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 d22, [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 d26, [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 d30, [r2] + add r2, r2, r8 + add r2, r2, #8 + b WriteEnd + Write3: + add lr, r2, #12 + str lr, [sp, #-40] + add r4, r2, #8 + vst1.32 d8, [r2] + vst1.32 d9[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d10, [r2] + vst1.32 d11[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d12, [r2] + vst1.32 d13[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d14, [r2] + vst1.32 d15[0], [r4] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d18, [r2] + vst1.32 d19[0], [r4] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d22, [r2] + vst1.32 d23[0], [r4] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d26, [r2] + vst1.32 d27[0], [r4] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d30, [r2] + vst1.32 d31[0], [r4] + add r2, r2, r8 + add r2, r2, #12 + b 
WriteEnd + Write4: + add lr, r2, #16 + str lr, [sp, #-40] + vst1.32 {d8, d9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d10, d11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d12, d13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d14, d15}, [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d16, d17}, [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d18, d19}, [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d22, d23}, [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d26, d27}, [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d30, d31}, [r2] + add r2, r2, r8 + add r2, r2, #16 + b WriteEnd + WriteEnd: + cmp r7, #4 + ble LoopColEnd + sub r7, r7, #4 // rhs col - 4 + b LoopCol + + LoopColEnd: + ldr r0, [sp, #-48] + add r0, r0, r12 // lhs ptr + stride + str r0, [sp, #-48] + mov lr, #4 + ldr r7, [sp, #12] // reload rhs col + mul lr, lr, r7 + sub r2, r2, lr + str r2, [sp, #-40] + cmp r6, #12 + ble LoopRowEnd + sub r6, r6, #12 // lhs row - 12 + b LoopRowStart + +LoopRowEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S new file mode 100644 index 0000000000000000000000000000000000000000..6dc036dea1aad873753599f3fddef38f92fb2a69 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S @@ -0,0 +1,298 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, +// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, +// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel); +// #-52: a, #-48: b, #-44: dst, #-40: row +// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp +// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel + +asm_function MatmulInt8Neon32 + push {r0-r11, lr} + vpush {q4-q7} + add sp, sp, #116 + + ldr r4, [sp] // col + ldr r7, [sp, #40] // output stride + mov r8, #0 // output channels offset + ldr r10, [sp, #44] + cmp r10, #0 + beq L1 + ldr r6, [sp, #8] // load input_sums ptr if per_channel +L1: + cmp r4, #0 // if at the end of col + ble End1 + + ldr r0, [sp, #-52] // reload a ptr + ldr r3, [sp, #-40] // reset row counter + ldr r10, [sp, #44] + cmp r10, #0 + bne L2 + ldr r6, [sp, #8] // reload input_sums ptr if per_tensor +L2: + cmp r3, #0 // if at the end of row + ble End2 + + ldr r1, [sp, #-48] // reload b ptr + ldr r5, [sp, #4] // reset deep16 + vmov.i32 q6, #0 + vmov.i32 q7, #0 + vmov.i32 q8, #0 + vmov.i32 q9, #0 + vmov.i32 q10, #0 + vmov.i32 q11, #0 + vmov.i32 q12, #0 + vmov.i32 q13, #0 +L3: + cmp r5, #0 + beq End3 + + vld1.8 {d0, d1, d2, d3}, [r0]! + vld1.8 {d8, d9, d10, d11}, [r1]! + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + + vpadal.s16 q6, q14 + vpadal.s16 q7, q2 + vpadal.s16 q8, q15 + vpadal.s16 q9, q3 + + vld1.8 {d0, d1, d2, d3}, [r0]! + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + + vpadal.s16 q10, q14 + vpadal.s16 q11, q2 + vpadal.s16 q12, q15 + vpadal.s16 q13, q3 + sub r5, r5, #16 // deep16 -= 16 + b L3 + +End3: + vpadd.i32 d0, d12, d13 + vpadd.i32 d1, d14, d15 + vpadd.i32 d2, d16, d17 + vpadd.i32 d3, d18, d19 + vpadd.i32 d4, d20, d21 + vpadd.i32 d5, d22, d23 + vpadd.i32 d6, d24, d25 + vpadd.i32 d7, d26, d27 + + vpadd.i32 d28, d0, d1 + vpadd.i32 d29, d2, d3 + vpadd.i32 d30, d4, d5 + vpadd.i32 d31, d6, d7 + + // Add weight_bias + ldr r9, [sp, #12] // reload weight_bias ptr + add r9, r9, r8 + vld1.32 {d26}, [r9]! + vadd.i32 d28, d28, d26 + vadd.i32 d29, d29, d26 + vadd.i32 d30, d30, d26 + vadd.i32 d31, d31, d26 + + ldr r10, [sp, #44] + cmp r10, #0 + bgt PerChannel + +PerTensor: + // Subtract input_sums + vld1.32 {d24, d25}, [r6]! + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vsub.s32 d28, d28, d20 + vsub.s32 d29, d29, d21 + vsub.s32 d30, d30, d22 + vsub.s32 d31, d31, d23 + + // Apply left shift + ldr r10, [sp, #32] + ldr r11, [r10]!
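+    // Per-tensor requantization, in effect:
+    //   acc = rounding_rshift(sat_rounding_doubling_high_mul(acc << left_shift, multiplier), right_shift)
+    // realized below as vshl + vqrdmulh plus the vand/vshr/vqadd/vrshl
+    // rounding-correction sequence.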
+ vdup.32 q9, r11 + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + ldr r11, [r10] + vdup.32 q8, r11 + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + // Apply right shift + ldr r10, [sp, #36] + ldr r11, [r10] + vdup.32 q7, r11 + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b AddDstZP + +PerChannel: + // Subtract input_sums + vld1.32 {d24, d25, d26, d27}, [r6]! + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + // Apply left shift + ldr r10, [sp, #32] + add r10, r10, r8 + vld1.32 {d23}, [r10] + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + add r10, r10, r8 + vld1.32 {d22}, [r10] + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + // Apply right shift + ldr r10, [sp, #36] + add r10, r10, r8 + vld1.32 {d21}, [r10] + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + +AddDstZP: + // Add the destination zero point + ldr r10, [sp, #24] + vdup.32 q4, r10 + vadd.i32 q14, q14, q4 + vadd.i32 q15, q15, q4 + + // Apply the act_min bound + ldr r10, [sp, #16] + vdup.32 q3, r10 + vmax.s32 q14, q14, q3 + vmax.s32 q15, q15, q3 + + // Apply the act_max bound + ldr r10, [sp, #20] + vdup.32 q2, r10 + vmin.s32 q14, q14, q2 + vmin.s32 q15, q15, q2 + + // Cast-and-saturate from int32 to int16 + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + // Cast-and-saturate from int16 to int8 + vqmovn.s16 d30, q14 + + // start to write + cmp r4, #2 + bge WriteCol2 + cmp r4, #1 + beq WriteCol1 + b EndWrite + +WriteCol2: + vst1.16 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.16 {d30[1]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.16 {d30[2]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.16 {d30[3]}, [r2], r7 + b EndWrite + +WriteCol1: + vst1.8 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.8 {d30[2]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.8 {d30[4]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.8 {d30[6]}, [r2], r7 + b EndWrite + +EndWrite: + sub r3, r3, #4 // a row counter -= 4 + b L2 + +End2: + sub r4, r4, #2 // b col counter -= 2 + ldr r1, [sp, #-48] // load b ptr + ldr r9, [sp, #4] + mov r10, #2 + mul r9, r9, r10 // the stride of b + add r1, r1, r9 // b ptr + stride + str r1, [sp, #-48] + ldr r2, [sp, #-44] // load dst ptr + add r2, r2, #2 // dst ptr + offset + str r2, [sp, #-44] + add r8, r8, #8 // output channels offset + 2*sizeof(int) + b L1 + +End1: + sub sp, sp, #116 + vpop {q4-q7} + pop {r0-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..16426bfd80a5a6eb0de27d06488f2b81f394c90a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S @@ -0,0 +1,300 @@ +/** + * 
Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, +// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, +// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, +// int *filter_zp); +// #-48: a, #-44: b, #-40: dst, #-36: row +// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp +// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp + +asm_function MatmulInt8Opt + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] // reload a_sums ptr + ldr r8, [sp, #40] + mov r10, #4 + mul r10, r10, r5 // lhs step + mov r11, #4 + mul r11, r11, r8 // dst step +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r4, [sp] // reload rhs col + ldr lr, [sp, #-40] + vmov.32 d4[0], lr // reload dst ptr + ldr lr, [sp, #32] + vmov.32 d4[1], lr // reload left shift + ldr lr, [sp, #28] + vmov.32 d5[0], lr // reload multiplier + ldr lr, [sp, #36] + vmov.32 d5[1], lr // reload right shift + ldr r7, [sp, #48] // reload filter_zp + ldr r12, [sp, #12] // reload bias ptr + + LoopCol: + vmov.32 r2, d4[0] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + + vmov.i32 q6, #0 + vmov.i32 q7, #0 + vmov.i32 q8, #0 + vmov.i32 q9, #0 + vmov.i32 q10, #0 + vmov.i32 q11, #0 + vmov.i32 q12, #0 + vmov.i32 q13, #0 + + LoopDepth: + vld1.8 {q0-q1}, [r0]! + vld1.8 {q4-q5}, [r1]! + vmull.s8 q14, d0, d8 + vmull.s8 q15, d2, d8 + vmlal.s8 q14, d1, d9 + vmlal.s8 q15, d3, d9 + vpadal.s16 q6, q14 + vpadal.s16 q8, q15 + vmull.s8 q14, d0, d10 + vmull.s8 q15, d2, d10 + vmlal.s8 q14, d1, d11 + vmlal.s8 q15, d3, d11 + vld1.8 {q0-q1}, [r0]! + + vpadal.s16 q7, q14 + vpadal.s16 q9, q15 + + vmull.s8 q14, d0, d8 + vmull.s8 q15, d2, d8 + vmlal.s8 q14, d1, d9 + vmlal.s8 q15, d3, d9 + vpadal.s16 q10, q14 + vpadal.s16 q12, q15 + vmull.s8 q14, d0, d10 + vmull.s8 q15, d2, d10 + vmlal.s8 q14, d1, d11 + vmlal.s8 q15, d3, d11 + + vpadal.s16 q11, q14 + vpadal.s16 q13, q15 + + cmp r5, #16 + ble LoopDepthEnd + sub r5, r5, #16 + b LoopDepth + + LoopDepthEnd: + vpadd.i32 d12, d12, d13 + vpadd.i32 d14, d14, d15 + vpadd.i32 d16, d16, d17 + vpadd.i32 d18, d18, d19 + vpadd.i32 d20, d20, d21 + vpadd.i32 d22, d22, d23 + vpadd.i32 d24, d24, d25 + vpadd.i32 d26, d26, d27 + + vpadd.i32 d28, d12, d14 + vpadd.i32 d29, d16, d18 + vpadd.i32 d30, d20, d22 + vpadd.i32 d31, d24, d26 + + Bias: + cmp r12, #0 + beq NoBias + vld1.32 {d26}, [r12]!
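+    // d26 holds the two per-column biases; each of d28-d31 is one output
+    // row's pair of column sums, so the same bias is added to all four rows.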
+ vadd.i32 d28, d28, d26 + vadd.i32 d29, d29, d26 + vadd.i32 d30, d30, d26 + vadd.i32 d31, d31, d26 + + NoBias: + ldr lr, [sp, #44] + cmp lr, #0 + bne PerChannel + + PerTensor: + vld1.32 {d24, d25}, [r6] + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vsub.s32 d28, d28, d20 + vsub.s32 d29, d29, d21 + vsub.s32 d30, d30, d22 + vsub.s32 d31, d31, d23 + + vmov.32 lr, d4[1] + vld1.32 {d18[], d19[]}, [lr] + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + vmov.32 lr, d5[0] + vld1.32 {d16[], d17[]}, [lr] + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + vmov.32 lr, d5[1] + vld1.32 {d14[], d15[]}, [lr] + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b Quantize + + PerChannel: + vld1.32 {d24, d25}, [r6] + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vld1.32 {d19}, [r7]! + vmul.s32 d24, d20, d19 + vmul.s32 d25, d21, d19 + vmul.s32 d26, d22, d19 + vmul.s32 d27, d23, d19 + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + vmov.32 lr, d4[1] + vld1.32 {d23}, [lr]! + vmov.32 d4[1], lr + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + vmov.32 lr, d5[0] + vld1.32 {d22}, [lr]! + vmov.32 d5[0], lr + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + vmov.32 lr, d5[1] + vld1.32 {d21}, [lr]! + vmov.32 d5[1], lr + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + + Quantize: + ldr lr, [sp, #24] + vdup.32 q0, lr + vadd.i32 q14, q14, q0 + vadd.i32 q15, q15, q0 + + ldr lr, [sp, #16] + vdup.32 q1, lr + vmax.s32 q14, q14, q1 + vmax.s32 q15, q15, q1 + + ldr lr, [sp, #20] + vdup.32 q0, lr + vmin.s32 q14, q14, q0 + vmin.s32 q15, q15, q0 + + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + vqmovn.s16 d30, q14 + + cmp r4, #1 + beq Write1 + b Write2 + + Write1: + vmov.32 lr, d4[0] + add lr, lr, #1 + vmov.32 d4[0], lr + vst1.8 {d30[0]}, [r2], r8 + cmp r3, #1 + beq WriteEnd + vst1.8 {d30[2]}, [r2], r8 + cmp r3, #2 + beq WriteEnd + vst1.8 {d30[4]}, [r2], r8 + cmp r3, #3 + beq WriteEnd + vst1.8 {d30[6]}, [r2], r8 + b WriteEnd + + Write2: + vmov.32 lr, d4[0] + add lr, lr, #2 + vmov.32 d4[0], lr + vst1.16 {d30[0]}, [r2], r8 + cmp r3, #1 + beq WriteEnd + vst1.16 {d30[1]}, [r2], r8 + cmp r3, #2 + beq WriteEnd + vst1.16 {d30[2]}, [r2], r8 + cmp r3, #3 + beq WriteEnd + vst1.16 {d30[3]}, [r2], r8 + + WriteEnd: + cmp r4, #2 + ble LoopColEnd + sub r4, r4, #2 + b LoopCol + +LoopColEnd: + cmp r3, #4 + ble LoopRowEnd + ldr lr, [sp, #-48] + add lr, lr, r10 + str lr, [sp, #-48] + ldr lr, [sp, #-40] + add lr, lr, r11 + str lr, [sp, #-40] + sub r3, r3, #4 + add r6, r6, #16 + b LoopRow + +LoopRowEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S new file mode 100644 index 
0000000000000000000000000000000000000000..49d30d0a6a76b9e77ff0eb5c3860d44ecd8f9a14 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S @@ -0,0 +1,186 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinograd(float *matrix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel) + // r0: matrix_a, r1: matrix_b, r2: matrix_c, r3: m, r4: k, r5: n, r6: in_channel, r7: c4_channel * 4 + // #-56: matrix_a, #-52: matrix_b, #-48: matrix_c, #-44: m, #0: k, #4: n, #8: in_channel, #12: c4_channel * 4 +asm_function MatrixMultiplyWinograd + // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r12, lr} + vpush {q4-q7} + add sp, sp, #120 + + mov r0, #4 + ldr r4, [sp, #4] // n + ldr r5, [sp, #8] // in_channel + ldr r6, [sp] // k + mul r5, r5, r0 // in_channel * 4 + mul r4, r4, r0 // n * 4 + mul r6, r6, r5 // in_channel * 4 * k + + // r3 = m + // r2 = dst + LoopM: + ldr r7, [sp, #4] // n + ldr r8, [sp, #-52] // matrix_b + LoopN: + ldr r0, [sp, #4] // n + ldr r1, [sp, #-44] // m + sub r0, r0, r7 // ni + mul r0, r0, r1 // ni * m + sub r1, r1, r3 // mi + add r0, r0, r1 // ni * m + mi + ldr r1, [sp, #12] + mul r9, r0, r1 // (ni * m + mi) * c4_channel * 4 + add r11, r2, r9 // dst + offset + + ldr r10, [sp, #8] // in_channel + ldr r9, [sp, #-56] // src + cmp r10, #16 + bge LoopC16 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + b EndLoopC + + LoopC16: + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + veor q5, q5, q5 + veor q6, q6, q6 + veor q7, q7, q7 + veor q8, q8, q8 + LoopK16: + vld1.32 {q0, q1}, [r9]! + vld1.32 {q2, q3}, [r9]! + add r9, r9, r5 + sub r9, r9, #64 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + vmla.f32 q6, q1, d8[0] + vmla.f32 q7, q2, d8[0] + vmla.f32 q8, q3, d8[0] + subs r12, r12, #1 + bne LoopK16 + Write16: + vst1.32 {q5, q6}, [r11]! + vst1.32 {q7, q8}, [r11]! + subs r10, r10, #16 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #64 + cmp r10, #16 + bge LoopC16 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC8: + veor q5, q5, q5 + veor q6, q6, q6 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK8: + vld1.32 {q0, q1}, [r9], r5 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + vmla.f32 q6, q1, d8[0] + subs r12, r12, #1 + bne LoopK8 + Write8: + vst1.32 {q5, q6}, [r11]!
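+    // After an 8-channel block: rewind src by k * in_channel * 4 bytes (r6),
+    // then step forward 8 floats (#32) to the next channel group.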
+ subs r10, r10, #8 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #32 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC4: + veor q5, q5, q5 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK4: + vld1.32 {q0}, [r9], r5 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + subs r12, r12, #1 + bne LoopK4 + Write4: + vst1.32 {q5}, [r11]! + subs r10, r10, #4 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #16 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC: + veor q2, q2, q2 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK: + vld1.32 d0[0], [r9], r5 + vld1.32 d2[0], [r0], r4 + vmla.f32 s8, s0, s4 + subs r12, r12, #1 + bne LoopK + Write: + vst1.32 d4[0], [r11]! + subs r10, r10, #1 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #4 + b LoopC + + EndLoopC: + subs r7, r7, #1 + beq EndLoopN + add r8, r8, #4 + b LoopN + EndLoopN: + subs r3, r3, #1 + beq EndLoopM + ldr r0, [sp, #-56] + add r0, r0, r6 + str r0, [sp, #-56] + b LoopM + EndLoopM: + sub sp, sp, #120 + vpop {q4-q7} + pop {r0-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S new file mode 100644 index 0000000000000000000000000000000000000000..d22ee8669cd4ecb580c1f08019d5aef7204f69f1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradPostFuncBiasReluC4 + push {r4-r8, r10, r11, lr} + add sp, sp, #32 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + mov lr, #4 + add r12, r3, r4 + mul r12, r12, lr + + mov lr, #0 + +Loop_C4: + cmp lr, r3 + beq Loop_C1 + mov r11, #4 + mul r10, lr, r11 + add r11, r0, r10 + add lr, lr, #4 + mov r8, r5 + vld1.32 {q12}, [r2]! + +Loop_4x4: + cmp r8, #4 + blt Loop_1x4 + sub r8, r8, #4 + vld1.32 {q0-q1}, [r1]! + vld1.32 {q2-q3}, [r1]! + + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q12 + vadd.f32 q2, q2, q12 + vadd.f32 q3, q3, q12 + + cmp r7, #3 + beq Relu6_4x4 + cmp r7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 +Relu_4x4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 +Write_4x4: + vst1.32 {q0}, [r11], r12 + vst1.32 {q1}, [r11], r12 + vst1.32 {q2}, [r11], r12 + vst1.32 {q3}, [r11], r12 + b Loop_4x4 + +Loop_1x4: + cmp r7, #3 + beq Relu6_1x4 + cmp r7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! 
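+    // relu6 tail: for each remaining row, out = min(max(in + bias, 0), 6).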
+ vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu6_1x4 +Relu_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu_1x4 +Write_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {q0}, [r11], r12 + b Write_1x4 + +HW_Add: + add r1, r1, r6 + b Loop_C4 + +Loop_C1: + cmp r4, #0 + beq End + mov r8, r5 + vld1.32 {q12}, [r2]! + mov r11, #4 + mul r10, lr, r11 + add r0, r0, r10 + + cmp r4, #1 + beq Loop_C1_1 + cmp r4, #2 + beq Loop_C1_2 + cmp r4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp r7, #3 + beq Loop_C1_1_Relu6 + cmp r7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp r7, #3 + beq Loop_C1_2_Relu6 + cmp r7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Write + +Loop_C1_3: + add r11, r0, #8 + cmp r7, #3 + beq Loop_C1_3_Relu6 + cmp r7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S new file mode 100644 index 0000000000000000000000000000000000000000..93b860ae6502dfe6f749721d42cf990f7e9aa2ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, +// size_t plane_size, size_t stride, int relu_type); +// r0 dst r1 src r2 bias +// r3 oc8div r4 oc8mod r5 plane_size +// r6 stride r7 relu_type + +// q0 ~ q11 value +// q12 q13 bias data +// r10 r11 write loop tmp buf +// q14 relu6 #6; q15 relu #0 +// lr oc8 loop control +// r8 hw loop control + +asm_function PostFuncBiasReluC8 + push {r4-r8, r10, r11, lr} + add sp, sp, #32 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + mov lr, #0 + +Loop_C8: + cmp lr, r3 + beq Loop_C1 + mov r11, #4 + mul r10, lr, r11 + add r11, r0, r10 + add lr, lr, #8 + mov r8, r5 + vld1.32 {q12-q13}, [r2]! + +Loop_4x8: + cmp r8, #4 + blt Loop_1x8 + sub r8, r8, #4 + vld1.32 {q0-q1}, [r1]! + vld1.32 {q2-q3}, [r1]! + vld1.32 {q8-q9}, [r1]! + vld1.32 {q10-q11}, [r1]! + + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vadd.f32 q2, q2, q12 + vadd.f32 q3, q3, q13 + vadd.f32 q8, q8, q12 + vadd.f32 q9, q9, q13 + vadd.f32 q10, q10, q12 + vadd.f32 q11, q11, q13 + + cmp r7, #3 + beq Relu6_4x8 + cmp r7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + vmin.f32 q8, q8, q14 + vmin.f32 q9, q9, q14 + vmin.f32 q10, q10, q14 + vmin.f32 q11, q11, q14 +Relu_4x8: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 + vmax.f32 q8, q8, q15 + vmax.f32 q9, q9, q15 + vmax.f32 q10, q10, q15 + vmax.f32 q11, q11, q15 +Write_4x8: + vst1.32 {q0-q1}, [r11], r6 + vst1.32 {q2-q3}, [r11], r6 + vst1.32 {q8-q9}, [r11], r6 + vst1.32 {q10-q11}, [r11], r6 + b Loop_4x8 + +Loop_1x8: + cmp r7, #3 + beq Relu6_1x8 + cmp r7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp r8, #0 + beq Loop_C8 + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0-q1}, [r11], r6 + b Relu6_1x8 +Relu_1x8: + cmp r8, #0 + beq Loop_C8 + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0-q1}, [r11], r6 + b Relu_1x8 +Write_1x8: + cmp r8, #0 + beq Loop_C8 + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0-q1}, [r11], r6 + b Write_1x8 + +Loop_C1: + cmp r4, #0 + beq End + mov r8, r5 + vld1.32 {q12-q13}, [r2]! + mov r11, #4 + mul r10, lr, r11 + add r0, r0, r10 + + cmp r4, #1 + beq Loop_C1_1 + cmp r4, #2 + beq Loop_C1_2 + cmp r4, #3 + beq Loop_C1_3 + cmp r4, #4 + beq Loop_C1_4 + cmp r4, #5 + beq Loop_C1_5 + cmp r4, #6 + beq Loop_C1_6 + cmp r4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp r7, #3 + beq Loop_C1_1_Relu6 + cmp r7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]!
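+    // oc8mod == 1: a full 8-wide block is loaded per row, but only lane
+    // d0[0] is biased, clamped and stored.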
+ vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0[0]}, [r0], r6 + b Loop_C1_1_Write + +Loop_C1_2: + cmp r7, #3 + beq Loop_C1_2_Relu6 + cmp r7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r6 + b Loop_C1_2_Write + +Loop_C1_3: + add r11, r0, #8 + cmp r7, #3 + beq Loop_C1_3_Relu6 + cmp r7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + vst1.32 {d1[0]}, [r11], r6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + vst1.32 {d1[0]}, [r11], r6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r6 + vst1.32 {d1[0]}, [r11], r6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp r7, #3 + beq Loop_C1_4_Relu6 + cmp r7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r0], r6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r0], r6 + b Loop_C1_4_Relu +Loop_C1_4_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {q0}, [r0], r6 + b Loop_C1_4_Write + +Loop_C1_5: + add r11, r0, #16 + cmp r7, #3 + beq Loop_C1_5_Relu6 + cmp r7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2[0]}, [r11], r6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2[0]}, [r11], r6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2[0]}, [r11], r6 + b Loop_C1_5_Write + +Loop_C1_6: + add r11, r0, #16 + cmp r7, #3 + beq Loop_C1_6_Relu6 + cmp r7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]!
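+    // oc8mod == 6: the six leftover channels are stored as 4 floats (q0) at
+    // r0 plus 2 floats (d2) at r0 + 16 (r11).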
+ vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Write + +Loop_C1_7: + add r11, r0, #16 + add r10, r0, #24 + cmp r7, #3 + beq Loop_C1_7_Relu6 + cmp r7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S new file mode 100644 index 0000000000000000000000000000000000000000..0a557ac303847986d115f62764396bab6e90c3b9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div2, +// size_t oc_res2, size_t stride); + +// r0 src +// r1 sum +// r2 zp +// r3 hw4 +// r4 ic16 +// r5 oc_div2 +// r6 oc_res2 +// r7 stride + +asm_function PreSum4x16Int8Peroc + push {r4-r11, lr} + vpush {q4-q7} + add sp, sp, #100 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mov r8, #0 + mov r10, #8 + +RowLoop: + cmp r8, r3 + beq End + add r8, r8, #4 + vmov.s32 q13, #0 + mov r9, #0 + mov r11, r2 + +Sum: + cmp r9, r4 + beq Mul + add r9, r9, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]! 
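+    // Widening pairwise-add tree (s8 -> s16 -> s32 -> s64, then narrow
+    // back): collapses each row's 16 int8 values into one int32 row sum,
+    // accumulated in q13.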
+ + vpaddl.s8 q4, q0 + vpaddl.s8 q5, q1 + vpaddl.s8 q6, q2 + vpaddl.s8 q7, q3 + + vpaddl.s16 q0, q4 + vpaddl.s16 q1, q5 + vpaddl.s16 q2, q6 + vpaddl.s16 q3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + vpaddl.s32 q6, q2 + vpaddl.s32 q7, q3 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + vqmovn.s64 d2, q6 + vqmovn.s64 d3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + + vadd.i32 q13, q13, q0 + b Sum + +Mul: + mov r12, r1 + add r1, r1, #32 + mov r9, #0 + + vdup.32 d1, d26[0] + vdup.32 d2, d26[1] + vdup.32 d3, d27[0] + vdup.32 d4, d27[1] + +Write: + + cmp r9, r5 + beq OcRes + add r9, r9, #2 + vld1.32 {d9}, [r11]! + + vmul.i32 d5, d1, d9 + vmul.i32 d6, d2, d9 + vmul.i32 d7, d3, d9 + vmul.i32 d8, d4, d9 + + vst1.32 d5, [r12], r10 + vst1.32 d6, [r12], r10 + vst1.32 d7, [r12], r10 + vst1.32 d8, [r12], r10 + add r12, r12, r7 + b Write + +OcRes: + cmp r6, #0 + beq RowLoop + + vmov.s32 d9, #0 + vld1.8 {d9[0]}, [r11] + + vmul.i32 d5, d1, d9 + vmul.i32 d6, d2, d9 + vmul.i32 d7, d3, d9 + vmul.i32 d8, d4, d9 + + vst1.32 d5, [r12], r10 + vst1.32 d6, [r12], r10 + vst1.32 d7, [r12], r10 + vst1.32 d8, [r12], r10 + b RowLoop + +End: + sub sp, sp, #100 + vpop {q4-q7} + pop {r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S new file mode 100644 index 0000000000000000000000000000000000000000..d0ad50c2c8e1ce4028e5071ac1c1799cb462b8f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S @@ -0,0 +1,94 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); + +// r0 src +// r1 sum +// r2 row4 +// r3 col16 +// r4 filter_zp + +asm_function PreSum4x16Int8Pert + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + + vdup.32 q10, r4 + mov r5, #0 + mov r7, #16 + +RowLoop: + cmp r5, r2 + beq End + add r5, r5, #4 + vmov.s32 q13, #0 + mov r6, #0 + +CalLoop: + cmp r6, r3 + beq Write + add r6, r6, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]!
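+    // Same widening pairwise-add tree as the per-oc-channel variant: the
+    // four 16-deep row sums end up as one int32 each in q13.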
+
+  vpaddl.s8 q4, q0
+  vpaddl.s8 q5, q1
+  vpaddl.s8 q6, q2
+  vpaddl.s8 q7, q3
+
+  vpaddl.s16 q0, q4
+  vpaddl.s16 q1, q5
+  vpaddl.s16 q2, q6
+  vpaddl.s16 q3, q7
+
+  vpaddl.s32 q4, q0
+  vpaddl.s32 q5, q1
+  vpaddl.s32 q6, q2
+  vpaddl.s32 q7, q3
+
+  vqmovn.s64 d0, q4
+  vqmovn.s64 d1, q5
+  vqmovn.s64 d2, q6
+  vqmovn.s64 d3, q7
+
+  vpaddl.s32 q4, q0
+  vpaddl.s32 q5, q1
+
+  vqmovn.s64 d0, q4
+  vqmovn.s64 d1, q5
+
+  vadd.i32 q13, q13, q0
+  b CalLoop
+
+Write:
+  vmul.i32 q13, q13, q10
+  vst1.32 {d26, d27}, [r1], r7
+  b RowLoop
+
+End:
+  sub sp, sp, #96
+  vpop {q4-q7}
+  pop {r4-r8, r10, r11, pc}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S
new file mode 100644
index 0000000000000000000000000000000000000000..9c725f71f6d4ed4e6b165b0a332d594593448da4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S
@@ -0,0 +1,211 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+asm_function TiledC4MatmulFp32
+//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t cal_num, size_t ic4, size_t oc4)
+//r0: dst
+//r1: src
+//r2: weight
+//r3: cal_num
+//r4: ic4 (loaded from [sp, #24])
+//r5: oc4 (loaded from [sp, #28])
+
+push {r4-r8, lr}
+ldr r4, [sp, #24]
+ldr r5, [sp, #28]
+// scale the step by sizeof(float)
+mov r8, #4
+mul r3, r8, r3
+
+vpush {q4-q7}
+
+LoopOc:
+  mov r6, r1
+  mov r8, r0
+  subs r7, r4, #1
+  vld1.32 {q0, q1}, [r1]!
+  vld1.32 {q2, q3}, [r1]!
+  vld1.32 {q4, q5}, [r2]!
+  vld1.32 {q6, q7}, [r2]!
+
+  vmul.f32 q8, q4, d0[0]
+  vmul.f32 q9, q4, d2[0]
+  vmul.f32 q10, q4, d4[0]
+  vmul.f32 q11, q4, d6[0]
+
+  vmla.f32 q8, q5, d0[1]
+  vmla.f32 q9, q5, d2[1]
+  vmla.f32 q10, q5, d4[1]
+  vmla.f32 q11, q5, d6[1]
+
+  vmla.f32 q8, q6, d1[0]
+  vmla.f32 q9, q6, d3[0]
+  vmla.f32 q10, q6, d5[0]
+  vmla.f32 q11, q6, d7[0]
+
+  vmla.f32 q8, q7, d1[1]
+  vmla.f32 q9, q7, d3[1]
+  vmla.f32 q10, q7, d5[1]
+  vmla.f32 q11, q7, d7[1]
+
+  vld1.32 {q0, q1}, [r1]!
+  vld1.32 {q2, q3}, [r1]!
+
+  vmul.f32 q12, q4, d0[0]
+  vmul.f32 q13, q4, d2[0]
+  vmul.f32 q14, q4, d4[0]
+  vmul.f32 q15, q4, d6[0]
+
+  vmla.f32 q12, q5, d0[1]
+  vmla.f32 q13, q5, d2[1]
+  vmla.f32 q14, q5, d4[1]
+  vmla.f32 q15, q5, d6[1]
+
+  vmla.f32 q12, q6, d1[0]
+  vmla.f32 q13, q6, d3[0]
+  vmla.f32 q14, q6, d5[0]
+  vmla.f32 q15, q6, d7[0]
+
+  vmla.f32 q12, q7, d1[1]
+  vmla.f32 q13, q7, d3[1]
+  vmla.f32 q14, q7, d5[1]
+  vmla.f32 q15, q7, d7[1]
+  beq LoopIcEnd
+
+  subs r7, r7, #1
+
+  vld1.32 {q4, q5}, [r2]!
+  vld1.32 {q0, q1}, [r1]!
+  vld1.32 {q2, q3}, [r1]!
+
+  vmla.f32 q8, q4, d0[0]
+  vmla.f32 q9, q4, d2[0]
+  beq LoopIcEndHalf
+
+  LoopIc:
+    vmla.f32 q10, q4, d4[0]
+    vmla.f32 q11, q4, d6[0]
+
+    vmla.f32 q8, q5, d0[1]
+    vmla.f32 q9, q5, d2[1]
+    vld1.32 {q6, q7}, [r2]!
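+    // The weight loads (q6/q7 here, q4/q5 further down) are deliberately
+    // interleaved with the vmla chain so the next 4x4 weight block streams
+    // in while the current accumulations retire (software pipelining
+    // rather than a data dependence).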
+ vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vld1.32 {q4, q5}, [r2]! + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q15, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + + subs r7, r7, #1 + bne LoopIc + LoopIcEndHalf: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + LoopIcEnd: + vst1.32 {q8, q9}, [r0]! + vst1.32 {q10, q11}, [r0]! + vst1.32 {q12, q13}, [r0]! + vst1.32 {q14, q15}, [r0]! + mov r1, r6 + + subs r5, r5, #1 + add r0, r8, r3 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S new file mode 100644 index 0000000000000000000000000000000000000000..757688894907f00298f0012e118bad5e562cd6d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S @@ -0,0 +1,230 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
+//r0: S
+//r1: B
+//r2: M
+//r3: w
+//r4: h (loaded from [sp, #36])
+//r5: k (loaded from [sp, #40])
+//r6: length (loaded from [sp, #44])
+asm_function WinogradTransLeft
+  push {r4-r11, lr}
+  ldr r4, [sp, #36]
+  ldr r5, [sp, #40]
+  ldr r6, [sp, #44]
+
+  mov r8, #16 // 4 * sizeof(float)
+  mul r8, r6, r8
+  mul r9, r3, r8
+  sub r9, r9, r8
+  add r7, r9, r8 // step for S
+  mov r10, #4
+  mul r10, r4, r10 // step for B
+
+LoopH:
+  push {r0, r3}
+  LoopW:
+    push {r0, r1}
+    vmov.i32 q14, #0
+    mov r11, r6
+    InitZero:
+      vst1.32 {q14}, [r2]!
+      subs r11, r11, #1
+      bne InitZero
+
+    sub r2, r2, r8
+    mov r12, r5
+
+    LoopKStart7:
+      cmp r12, #7
+      blt LoopKStart4
+      push {r3-r7}
+      LoopK7:
+        vld1.32 {d0[0]}, [r1], r10
+        vld1.32 {d0[1]}, [r1], r10
+        vld1.32 {d1[0]}, [r1], r10
+        vld1.32 {d1[1]}, [r1], r10
+        vld1.32 {d2[0]}, [r1], r10
+        vld1.32 {d2[1]}, [r1], r10
+        vld1.32 {d3[0]}, [r1], r10
+        mov r11, r6
+        vmov.32 d30[0], r1
+
+        add r1, r0, r7
+        add r3, r1, r7
+        add r4, r3, r7
+        add r5, r4, r7
+        add r6, r5, r7
+        add r7, r6, r7
+
+        LoopLength7:
+          vld1.32 {q8}, [r2]
+          vld1.32 {q12}, [r0]!
+          vmla.f32 q8, q12, d0[0]
+          vld1.32 {q13}, [r1]!
+          vmul.f32 q9, q13, d0[1]
+          vld1.32 {q12}, [r3]!
+          vmla.f32 q8, q12, d1[0]
+          vld1.32 {q13}, [r4]!
+          vmla.f32 q9, q13, d1[1]
+          vld1.32 {q12}, [r5]!
+          vmla.f32 q8, q12, d2[0]
+          vld1.32 {q13}, [r6]!
+          vmla.f32 q9, q13, d2[1]
+          vld1.32 {q12}, [r7]!
+          vmla.f32 q8, q12, d3[0]
+
+          vadd.f32 q9, q8, q9
+          vst1.32 {q9}, [r2]!
+          subs r11, r11, #1
+          bne LoopLength7
+
+        sub r2, r2, r8
+        sub r12, r12, #7
+        add r0, r7, r9
+        vmov.32 r1, d30[0]
+        cmp r12, #7
+        bge LoopK7
+
+      pop {r3-r7}
+
+    LoopKStart4:
+      cmp r12, #4
+      blt LoopKStart3
+      vmov.32 d30[1], r3
+      vmov.32 d31[0], r4
+      LoopK4:
+        vld1.32 {d0[0]}, [r1], r10
+        vld1.32 {d0[1]}, [r1], r10
+        vld1.32 {d1[0]}, [r1], r10
+        vld1.32 {d1[1]}, [r1], r10
+        mov r11, r6
+        vmov.32 d30[0], r1
+
+        add r1, r0, r7
+        add r3, r1, r7
+        add r4, r3, r7
+
+        LoopLength4:
+          vld1.32 {q8}, [r2]
+          vld1.32 {q12}, [r0]!
+          vmla.f32 q8, q12, d0[0]
+          vld1.32 {q13}, [r1]!
+          vmul.f32 q9, q13, d0[1]
+          vld1.32 {q12}, [r3]!
+          vmla.f32 q8, q12, d1[0]
+          vld1.32 {q13}, [r4]!
+          vmla.f32 q9, q13, d1[1]
+
+          vadd.f32 q9, q8, q9
+          vst1.32 {q9}, [r2]!
+          subs r11, r11, #1
+          bne LoopLength4
+
+        sub r2, r2, r8
+        sub r12, r12, #4
+        add r0, r4, r9
+        vmov.32 r1, d30[0]
+        cmp r12, #4
+        bge LoopK4
+
+      vmov.32 r3, d30[1]
+      vmov.32 r4, d31[0]
+
+    LoopKStart3:
+      cmp r12, #3
+      blt LoopKStart
+      vmov.32 d30[1], r3
+      vmov.32 d31[0], r4
+      LoopK3:
+        vld1.32 {d0[0]}, [r1], r10
+        vld1.32 {d0[1]}, [r1], r10
+        vld1.32 {d1[0]}, [r1], r10
+        mov r11, r6
+        vmov.32 d30[0], r1
+
+        add r1, r0, r7
+        add r3, r1, r7
+
+        LoopLength3:
+          vld1.32 {q8}, [r2]
+          vld1.32 {q12}, [r0]!
+          vmla.f32 q8, q12, d0[0]
+          vld1.32 {q13}, [r1]!
+          vmul.f32 q9, q13, d0[1]
+          vld1.32 {q12}, [r3]!
+          vmla.f32 q8, q12, d1[0]
+
+          vadd.f32 q9, q8, q9
+          vst1.32 {q9}, [r2]!
+          subs r11, r11, #1
+          bne LoopLength3
+
+        sub r2, r2, r8
+        sub r12, r12, #3
+        add r0, r3, r9
+        vmov.32 r1, d30[0]
+        cmp r12, #3
+        bge LoopK3
+
+      vmov.32 r3, d30[1]
+      vmov.32 r4, d31[0]
+
+    LoopKStart:
+      cmp r12, #0
+      beq LoopKEnd
+
+      LoopK:
+        vld1.32 {d30[0]}, [r1], r10
+
+        vdup.32 q15, d30[0]
+        mov r11, r6
+        LoopLength:
+          vld1.32 {q0}, [r2]
+          vld1.32 {q1}, [r0]!
+          vmla.f32 q0, q1, q15
+
+          vst1.32 {q0}, [r2]!
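+          // q0 now holds the updated M tile for this remainder k-row; r11
+          // counts down the per-row vector length and r12 the leftover k
+          // rows (fewer than 3 remain after the unrolled 7/4/3-row paths).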
+          subs r11, r11, #1
+          bne LoopLength
+        subs r12, r12, #1
+
+        sub r2, r2, r8
+        add r0, r0, r9
+        bne LoopK
+
+      LoopKEnd:
+        pop {r0, r1}
+        subs r3, r3, #1
+        add r0, r0, r8
+        add r2, r2, r8
+        bne LoopW
+
+  pop {r0, r3}
+  add r1, r1, #4 // sizeof(float)
+  subs r4, r4, #1
+  bne LoopH
+
+  pop {r4-r11, pc}
+
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S
new file mode 100644
index 0000000000000000000000000000000000000000..2abb31abc9bf616685503bb7a24c6970d067e637
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S
@@ -0,0 +1,220 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length);
+//r0: S
+//r1: B
+//r2: M
+//r3: w
+//r4: h (loaded from [sp, #36])
+//r5: k (loaded from [sp, #40])
+//r6: length (loaded from [sp, #44])
+asm_function WinogradTransRight
+  push {r4-r11, lr}
+  ldr r4, [sp, #36]
+  ldr r5, [sp, #40]
+  ldr r6, [sp, #44]
+
+  mov r8, #16 // 4 * sizeof(float)
+  mul r8, r6, r8
+  mul r9, r5, r8 // step for S
+  mov r10, #4
+  mul r10, r4, r10 // step for B
+
+LoopH:
+  push {r1, r3}
+  LoopW:
+    push {r0, r1}
+    vmov.i32 q14, #0
+    mov r11, r6
+    InitZero:
+      vst1.32 {q14}, [r2]!
+      subs r11, r11, #1
+      bne InitZero
+
+    sub r2, r2, r8
+    mov r12, r5
+    LoopKStart7:
+      cmp r12, #7
+      blt LoopKStart4
+      push {r3-r7}
+      LoopK7:
+        vld1.32 {d0[0]}, [r1], r10
+        vld1.32 {d0[1]}, [r1], r10
+        vld1.32 {d1[0]}, [r1], r10
+        vld1.32 {d1[1]}, [r1], r10
+        vld1.32 {d2[0]}, [r1], r10
+        vld1.32 {d2[1]}, [r1], r10
+        vld1.32 {d3[0]}, [r1], r10
+        mov r11, r6
+        vmov.32 d30[0], r1
+
+        add r1, r0, r8
+        add r3, r1, r8
+        add r4, r3, r8
+        add r5, r4, r8
+        add r6, r5, r8
+        add r7, r6, r8
+        LoopLength7:
+          vld1.32 {q8}, [r2]
+          vld1.32 {q12}, [r0]!
+          vmla.f32 q8, q12, d0[0]
+          vld1.32 {q13}, [r1]!
+          vmul.f32 q9, q13, d0[1]
+          vld1.32 {q12}, [r3]!
+          vmla.f32 q8, q12, d1[0]
+          vld1.32 {q13}, [r4]!
+          vmla.f32 q9, q13, d1[1]
+          vld1.32 {q12}, [r5]!
+          vmla.f32 q8, q12, d2[0]
+          vld1.32 {q13}, [r6]!
+          vmla.f32 q9, q13, d2[1]
+          vld1.32 {q12}, [r7]!
+          vmla.f32 q8, q12, d3[0]
+
+          vadd.f32 q9, q8, q9
+          vst1.32 {q9}, [r2]!
+          subs r11, r11, #1
+          bne LoopLength7
+
+        sub r2, r2, r8
+        sub r12, r12, #7
+        mov r0, r7
+        vmov.32 r1, d30[0]
+        cmp r12, #7
+        bge LoopK7
+
+      pop {r3-r7}
+
+    LoopKStart4:
+      cmp r12, #4
+      blt LoopKStart3
+      vmov.32 d30[1], r3
+      vmov.32 d31[0], r4
+      LoopK4:
+        vld1.32 {d0[0]}, [r1], r10
+        vld1.32 {d0[1]}, [r1], r10
+        vld1.32 {d1[0]}, [r1], r10
+        vld1.32 {d1[1]}, [r1], r10
+        mov r11, r6
+        vmov.32 d30[0], r1
+
+        add r1, r0, r8
+        add r3, r1, r8
+        add r4, r3, r8
+
+        LoopLength4:
+          vld1.32 {q8}, [r2]
+          vld1.32 {q12}, [r0]!
+          vmla.f32 q8, q12, d0[0]
+          vld1.32 {q13}, [r1]!
+          vmul.f32 q9, q13, d0[1]
+          vld1.32 {q12}, [r3]!
+          vmla.f32 q8, q12, d1[0]
+          vld1.32 {q13}, [r4]!
+ vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + mov r0, r4 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + mov r0, r3 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + LoopK: + vld1.32 {d30[0]}, [r1], r10 + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! + subs r11, r11, #1 + bne LoopLength + + subs r12, r12, #1 + sub r2, r2, r8 + bne LoopK + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r2, r2, r8 + add r1, r1, #4 //sizeof(float) + bne LoopW + + pop {r1, r3} + add r0, r0, r9 + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..863d3cea6840c4cd8570e4a7d6d81125578f343a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S @@ -0,0 +1,622 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function AdderFloatNeon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + ldr x8, [sp, #144] + + mov x20, #48 // sizeof(float) * 12 + mul x17, x5, x20 // block stride of lhs/rhs: sizeof(float) * 12 * depth + + mov x20, #4 + mul x8, x8, x20 + +LoopRowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + blt LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v25.4s, v3.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v27.4s, v3.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v29.4s, v3.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v31.4s, v3.4s, v30.4s + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v24.4s, v3.4s, v24.4s + fadd v25.4s, v25.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v26.4s, v3.4s, v26.4s + fadd v27.4s, v27.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v28.4s, v3.4s, v28.4s + fadd v29.4s, v29.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v30.4s, v3.4s, v30.4s + fadd v31.4s, v31.4s, v30.4s + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + fneg v25.4s, v25.4s + fneg v27.4s, v27.4s + fneg v29.4s, v29.4s + fneg v31.4s, v31.4s + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + fadd v25.4s, v25.4s, v0.4s + fadd v27.4s, v27.4s, v0.4s + fadd v29.4s, v29.4s, v0.4s + fadd v31.4s, v31.4s, v0.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + 
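+  // Activation dispatch above: act_type 3 lands on Relu6, which clamps to
+  // [0, 6] by fmin with 6 and then falls through into Relu for the fmax
+  // with 0; act_type 1 lands on Relu and only applies the fmax.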
Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + Relu8: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov 
x19, x5 // reload depth
+
+    LoopDepthStart4:
+      ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
+      ld1 {v3.4s}, [x14], #16
+      dup v8.4s, v0.s[0]
+      fabd v9.4s, v3.4s, v8.4s
+      dup v10.4s, v0.s[1]
+      fabd v11.4s, v3.4s, v10.4s
+      dup v12.4s, v0.s[2]
+      fabd v13.4s, v3.4s, v12.4s
+      dup v14.4s, v0.s[3]
+      fabd v15.4s, v3.4s, v14.4s
+
+      subs x19, x19, #1
+      beq Bias4
+
+    LoopDepth4:
+      ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48
+      ld1 {v3.4s}, [x14], #16
+      dup v8.4s, v0.s[0]
+      fabd v8.4s, v3.4s, v8.4s
+      fadd v9.4s, v9.4s, v8.4s
+      dup v10.4s, v0.s[1]
+      fabd v10.4s, v3.4s, v10.4s
+      fadd v11.4s, v11.4s, v10.4s
+      dup v12.4s, v0.s[2]
+      fabd v12.4s, v3.4s, v12.4s
+      fadd v13.4s, v13.4s, v12.4s
+      dup v14.4s, v0.s[3]
+      fabd v14.4s, v3.4s, v14.4s
+      fadd v15.4s, v15.4s, v14.4s
+
+      subs x19, x19, #1
+      bgt LoopDepth4
+
+    Bias4:
+      fneg v9.4s, v9.4s
+      fneg v11.4s, v11.4s
+      fneg v13.4s, v13.4s
+      fneg v15.4s, v15.4s
+      cbz x3, Activation4
+      ld1 {v0.4s}, [x12], #16
+
+      fadd v9.4s, v9.4s, v0.4s
+      fadd v11.4s, v11.4s, v0.4s
+      fadd v13.4s, v13.4s, v0.4s
+      fadd v15.4s, v15.4s, v0.4s
+
+    Activation4:
+      cmp x4, #3
+      beq Relu64
+      cmp x4, #1
+      beq Relu4
+      b Write
+
+    Relu64:
+      mov w19, #6
+      dup v2.4s, w19
+      scvtf v2.4s, v2.4s
+      fmin v9.4s, v9.4s, v2.4s
+      fmin v11.4s, v11.4s, v2.4s
+      fmin v13.4s, v13.4s, v2.4s
+      fmin v15.4s, v15.4s, v2.4s
+
+    Relu4:
+      dup v3.4s, wzr
+      fmax v9.4s, v9.4s, v3.4s
+      fmax v11.4s, v11.4s, v3.4s
+      fmax v13.4s, v13.4s, v3.4s
+      fmax v15.4s, v15.4s, v3.4s
+      b Write
+
+    Write:
+      cmp x13, #1
+      beq Write1
+      cmp x13, #2
+      beq Write2
+      cmp x13, #3
+      beq Write3
+      b Write4
+
+    Write1:
+      add x2, x2, #4
+      str s9, [x11]
+      cmp x6, #1
+      beq WriteEnd
+      add x11, x11, x8
+      str s11, [x11]
+      cmp x6, #2
+      beq WriteEnd
+      add x11, x11, x8
+      str s13, [x11]
+      cmp x6, #3
+      beq WriteEnd
+      add x11, x11, x8
+      str s15, [x11]
+      cmp x6, #4
+      beq WriteEnd
+      add x11, x11, x8
+      str s17, [x11]
+      cmp x6, #5
+      beq WriteEnd
+      add x11, x11, x8
+      str s19, [x11]
+      cmp x6, #6
+      beq WriteEnd
+      add x11, x11, x8
+      str s21, [x11]
+      cmp x6, #7
+      beq WriteEnd
+      add x11, x11, x8
+      str s23, [x11]
+      cmp x6, #8
+      beq WriteEnd
+      add x11, x11, x8
+      str s25, [x11]
+      cmp x6, #9
+      beq WriteEnd
+      add x11, x11, x8
+      str s27, [x11]
+      cmp x6, #10
+      beq WriteEnd
+      add x11, x11, x8
+      str s29, [x11]
+      cmp x6, #11
+      beq WriteEnd
+      add x11, x11, x8
+      str s31, [x11]
+      add x11, x11, x8
+      add x11, x11, #4
+      b WriteEnd
+    Write2:
+      add x2, x2, #8
+      st1 {v9.2s}, [x11], x8
+      cmp x6, #1
+      beq WriteEnd
+      st1 {v11.2s}, [x11], x8
+      cmp x6, #2
+      beq WriteEnd
+      st1 {v13.2s}, [x11], x8
+      cmp x6, #3
+      beq WriteEnd
+      st1 {v15.2s}, [x11], x8
+      cmp x6, #4
+      beq WriteEnd
+      st1 {v17.2s}, [x11], x8
+      cmp x6, #5
+      beq WriteEnd
+      st1 {v19.2s}, [x11], x8
+      cmp x6, #6
+      beq WriteEnd
+      st1 {v21.2s}, [x11], x8
+      cmp x6, #7
+      beq WriteEnd
+      st1 {v23.2s}, [x11], x8
+      cmp x6, #8
+      beq WriteEnd
+      st1 {v25.2s}, [x11], x8
+      cmp x6, #9
+      beq WriteEnd
+      st1 {v27.2s}, [x11], x8
+      cmp x6, #10
+      beq WriteEnd
+      st1 {v29.2s}, [x11], x8
+      cmp x6, #11
+      beq WriteEnd
+      st1 {v31.2s}, [x11], x8
+      add x11, x11, #8
+      b WriteEnd
+    Write3:
+      add x2, x2, #12
+      add x19, x11, #8
+      st1 {v9.2s}, [x11], x8
+      st1 {v9.s}[2], [x19], x8
+      cmp x6, #1
+      beq WriteEnd
+      st1 {v11.2s}, [x11], x8
+      st1 {v11.s}[2], [x19], x8
+      cmp x6, #2
+      beq WriteEnd
+      st1 {v13.2s}, [x11], x8
+      st1 {v13.s}[2], [x19], x8
+      cmp x6, #3
+      beq WriteEnd
+      st1 {v15.2s}, [x11], x8
+      st1 {v15.s}[2], [x19], x8
+      cmp x6, #4
+      beq WriteEnd
+      st1 {v17.2s}, [x11], x8
+      st1 {v17.s}[2], [x19], x8
+      cmp x6, #5
+      beq WriteEnd
+      st1 {v19.2s}, [x11], x8
+      st1 {v19.s}[2], [x19], x8
+      cmp
x6, #6 + beq WriteEnd + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.2s}, [x11], x8 + st1 {v31.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + + WriteEnd: + subs x13, x13, #4 // rhs col - 4 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + mov x20, #4 + mul x20, x20, x7 + sub x11, x11, x20 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..6097397d7a274bca447f237a181fc9ecc7773f0b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S @@ -0,0 +1,2528 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void BigMatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride + +asm_function BigMatmulFloatNeon64Opt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + stp x29, x30, [sp, #208] + + ldr x8, [sp, #224] + mov x20, #1 + mov x22, #32 + mov x23, #48 + mul x26, x5, x23 // stride for lhs + mul x24, x8, x23 // stride for out + lsl x27, x23, #9 // stride by depth for lhs + lsl x28, x22, #9 // stride by depth for rhs + lsl x22, x5, #5 // stride for rhs + lsl x8, x8, #2 + subs x5, x5, #512 + ble DepthTail +Depth512: + mov x25, x0 // restore lhs + mov x13, x2 // out + mov x10, x6 // restore row + RowStart: + mov x12, x1 // rhs + mov x14, x13 // out + mov x15, x3 // restore bias + mov x9, x7 // restore col + cmp x10, #4 + ble LoopRow4 + cmp x10, #8 + ble LoopRow8 + + LoopRow12: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol12x4 + + LoopCol12x8: + cbz x20, Reload12x8 + cbnz x15, InitFromBias12x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b Compute12x8Enter + InitFromBias12x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + ld1 {v24.4s, v25.4s}, [x15] + ld1 {v26.4s, v27.4s}, [x15] + ld1 {v28.4s, v29.4s}, [x15] + ld1 {v30.4s, v31.4s}, [x15] + add x15, x15, #32 + b Compute12x8Enter + Reload12x8: + bl Reload + Compute12x8Enter: + cbz x21, Write + bl Compute12x8Unit + b Write + + LoopCol12x4: + cbz x20, Reload12x4 + cbnz x15, InitFromBias12x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b Compute12x4Enter + InitFromBias12x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + ld1 {v24.4s}, [x15] + ld1 {v26.4s}, [x15] + ld1 {v28.4s}, [x15] + ld1 {v30.4s}, [x15] + b Compute12x4Enter + Reload12x4: + bl Reload + Compute12x4Enter: + cbz x21, Write + bl Compute12x4Unit + b Write + + LoopRow8: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol8x4 + + LoopCol8x8: + cbz x20, Reload8x8 + cbnz x15, InitFromBias8x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup 
v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b Compute8x8Enter + InitFromBias8x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + add x15, x15, #32 + b Compute8x8Enter + Reload8x8: + bl Reload + Compute8x8Enter: + cbz x21, Write + bl Compute8x8Unit + b Write + + LoopCol8x4: + cbz x20, Reload8x4 + cbnz x15, InitFromBias8x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b Compute8x4Enter + InitFromBias8x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + b Compute8x4Enter + Reload8x4: + bl Reload + Compute8x4Enter: + cbz x21, Write + bl Compute8x4Unit + b Write + + LoopRow4: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol4x4 + + LoopCol4x8: + cbz x20, Reload4x8 + cbnz x15, InitFromBias4x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b Compute4x8Enter + InitFromBias4x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + add x15, x15, #32 + b Compute4x8Enter + Reload4x8: + bl Reload + Compute4x8Enter: + cbz x21, Write + bl Compute4x8Unit + b Write + + LoopCol4x4: + cbz x20, Reload4x4 + cbnz x15, InitFromBias4x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b Compute4x4Enter + InitFromBias4x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + b Compute4x4Enter + Reload4x4: + bl Reload + Compute4x4Enter: + cbz x21, Write + bl Compute4x4Unit + +Write: + mov x21, x14 + cmp x9, #1 + beq Write1 + cmp x9, #2 + beq Write2 + cmp x9, #3 + beq Write3 + cmp x9, #4 + beq Write4 + cmp x9, #5 + beq Write5 + cmp x9, #6 + beq Write6 + cmp x9, #7 + beq Write7 + b Write8 + + Write1: + str s8, [x21] + cmp x10, #1 + beq LoopCol + add x21, x21, x8 + str s10, [x21] + cmp x10, #2 + beq LoopCol + add x21, x21, x8 + str s12, [x21] + cmp x10, #3 + beq LoopCol + add x21, x21, x8 + str s14, [x21] + cmp x10, #4 + beq LoopCol + add x21, x21, x8 + str s16, [x21] + cmp x10, #5 + beq LoopCol + add x21, x21, x8 + str s18, [x21] + cmp x10, #6 + beq LoopCol + add x21, x21, x8 + str s20, [x21] + cmp x10, #7 + beq LoopCol + add x21, x21, x8 + str s22, [x21] + cmp x10, #8 + beq LoopCol + add x21, x21, x8 + str s24, [x21] + cmp x10, #9 + beq LoopCol + add x21, x21, x8 + str s26, [x21] + cmp x10, #10 + beq LoopCol + add x21, x21, x8 + str s28, [x21] + cmp x10, #11 + beq LoopCol + add x21, x21, x8 + str s30, [x21] + b LoopCol + Write2: + st1 {v8.2s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.2s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.2s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.2s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.2s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.2s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.2s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 
{v22.2s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.2s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.2s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.2s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.2s}, [x21], x8 + add x21, x21, #8 + b LoopCol + Write3: + add x11, x21, #8 + st1 {v8.2s}, [x21], x8 + st1 {v8.s}[2], [x11], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.2s}, [x21], x8 + st1 {v10.s}[2], [x11], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.2s}, [x21], x8 + st1 {v12.s}[2], [x11], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.2s}, [x21], x8 + st1 {v14.s}[2], [x11], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.2s}, [x21], x8 + st1 {v16.s}[2], [x11], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.2s}, [x21], x8 + st1 {v18.s}[2], [x11], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.2s}, [x21], x8 + st1 {v20.s}[2], [x11], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.2s}, [x21], x8 + st1 {v22.s}[2], [x11], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.2s}, [x21], x8 + st1 {v24.s}[2], [x11], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.2s}, [x21], x8 + st1 {v26.s}[2], [x11], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.2s}, [x21], x8 + st1 {v28.s}[2], [x11], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.2s}, [x21], x8 + st1 {v30.s}[2], [x11] + add x21, x21, #12 + b LoopCol + Write4: + st1 {v8.4s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + add x21, x21, #16 + b LoopCol + Write5: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + str s9, [x11] + cmp x10, #1 + beq LoopCol + add x11, x11, x8 + st1 {v10.4s}, [x21], x8 + str s11, [x11] + cmp x10, #2 + beq LoopCol + add x11, x11, x8 + st1 {v12.4s}, [x21], x8 + str s13, [x11] + cmp x10, #3 + beq LoopCol + add x11, x11, x8 + st1 {v14.4s}, [x21], x8 + str s15, [x11] + cmp x10, #4 + beq LoopCol + add x11, x11, x8 + st1 {v16.4s}, [x21], x8 + str s17, [x11] + cmp x10, #5 + beq LoopCol + add x11, x11, x8 + st1 {v18.4s}, [x21], x8 + str s19, [x11] + cmp x10, #6 + beq LoopCol + add x11, x11, x8 + st1 {v20.4s}, [x21], x8 + str s21, [x11] + cmp x10, #7 + beq LoopCol + add x11, x11, x8 + st1 {v22.4s}, [x21], x8 + str s23, [x11] + cmp x10, #8 + beq LoopCol + add x11, x11, x8 + st1 {v24.4s}, [x21], x8 + str s25, [x11] + cmp x10, #9 + beq LoopCol + add x11, x11, x8 + st1 {v26.4s}, [x21], x8 + str s27, [x11] + cmp x10, #10 + beq LoopCol + add x11, x11, x8 + st1 {v28.4s}, [x21], x8 + str s29, [x11] + cmp x10, #11 + beq LoopCol + add x11, x11, x8 + st1 {v30.4s}, [x21], x8 + str s31, [x11] + add x21, x21, #20 + b LoopCol + Write6: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + cmp x10, #6 + beq LoopCol + st1 
{v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + add x21, x21, #24 + b LoopCol + Write7: + add x11, x21, #16 + add x23, x21, #24 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x23], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x23], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x23], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x23], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x23], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x23], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x23], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x23], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x23], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x23], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x23], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + st1 {v31.s}[2], [x23] + add x21, x21, #28 + b LoopCol + + Write8: + st1 {v8.4s, v9.4s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s, v11.4s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s, v13.4s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s, v15.4s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s, v17.4s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s, v19.4s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s, v21.4s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s, v23.4s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s, v25.4s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s, v27.4s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s, v29.4s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s, v31.4s}, [x21], x8 + add x21, x21, #32 + b LoopCol + +LoopCol: + subs x9, x9, #8 + ble LoopColEnd + add x12, x12, x22 // update rhs + add x14, x14, #32 // update out + cmp x10, #4 + ble LoopRow4 + cmp x10, #8 + ble LoopRow8 + b LoopRow12 + +LoopColEnd: + add x25, x25, x26 // update lhs + add x13, x13, x24 // update out + subs x10, x10, #12 // update row + bgt RowStart + mov x20, #0 + add x0, x0, x27 // update lhs by depth + add x1, x1, x28 // update rhs by depth + subs x5, x5, #512 + bgt Depth512 + +/////////////////////////////////////////////////////// + +DepthTail: + add x5, x5, #512 + mov x13, x2 // out + mov x10, x6 + TailRowStart: + mov x12, x1 // rhs + mov x14, x13 // out + mov x15, x3 // restore bias + mov x9, x7 // restore col + cmp x10, #4 + ble LoopTailRow4 + cmp x10, #8 + ble LoopTailRow8 + + LoopTailRow12: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol12x4 + + 
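+  // Tail-depth 12x8 tile: on the first 512-deep pass (x20 != 0) the v8-v31
+  // accumulators start from bias when x15 is non-null, else from zero; on
+  // later passes they are reloaded from the partial results via Reload,
+  // mirroring the main Depth512 loop above.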
LoopTailCol12x8: + cbz x20, ReloadTail12x8 + cbnz x15, InitTailFromBias12x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b ComputeTail12x8Enter + InitTailFromBias12x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + ld1 {v24.4s, v25.4s}, [x15] + ld1 {v26.4s, v27.4s}, [x15] + ld1 {v28.4s, v29.4s}, [x15] + ld1 {v30.4s, v31.4s}, [x15] + add x15, x15, #32 + b ComputeTail12x8Enter + ReloadTail12x8: + bl Reload + ComputeTail12x8Enter: + cbz x21, Activation12x8 + bl Compute12x8Unit + Activation12x8: + cmp x4, #3 + beq Relu612x8 + cmp x4, #1 + beq Relu12x8 + b WriteTail + + Relu612x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu12x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b WriteTail + + LoopTailCol12x4: + cbz x20, ReloadTail12x4 + cbnz x15, InitTailFromBias12x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b ComputeTail12x4Enter + InitTailFromBias12x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + ld1 {v24.4s}, [x15] + ld1 {v26.4s}, [x15] + ld1 {v28.4s}, [x15] + ld1 {v30.4s}, [x15] + b ComputeTail12x4Enter + ReloadTail12x4: + bl Reload + ComputeTail12x4Enter: + cbz x21, Activation12x4 + bl Compute12x4Unit + Activation12x4: + cmp x4, #3 + beq Relu612x4 + cmp x4, #1 + beq Relu12x4 + b WriteTail + + Relu612x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, 
v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + Relu12x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + b WriteTail + + LoopTailRow8: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol8x4 + + LoopTailCol8x8: + cbz x20, ReloadTail8x8 + cbnz x15, InitTailFromBias8x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b ComputeTail8x8Enter + InitTailFromBias8x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + add x15, x15, #32 + b ComputeTail8x8Enter + ReloadTail8x8: + bl Reload + ComputeTail8x8Enter: + cbz x21, Activation8x8 + bl Compute8x8Unit + Activation8x8: + cmp x4, #3 + beq Relu68x8 + cmp x4, #1 + beq Relu8x8 + b WriteTail + + Relu68x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b WriteTail + + LoopTailCol8x4: + cbz x20, ReloadTail8x4 + cbnz x15, InitTailFromBias8x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b ComputeTail8x4Enter + InitTailFromBias8x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + b ComputeTail8x4Enter + ReloadTail8x4: + bl Reload + ComputeTail8x4Enter: + cbz x21, Activation8x4 + bl Compute8x4Unit + Activation8x4: + cmp x4, #3 + beq Relu68x4 + cmp x4, #1 + beq Relu8x4 + b WriteTail + + Relu68x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, 
v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + Relu8x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b WriteTail + + LoopTailRow4: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol4x4 + + LoopTailCol4x8: + cbz x20, ReloadTail4x8 + cbnz x15, InitTailFromBias4x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b ComputeTail4x8Enter + InitTailFromBias4x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + add x15, x15, #32 + b ComputeTail4x8Enter + ReloadTail4x8: + bl Reload + ComputeTail4x8Enter: + cbz x21, Activation4x8 + bl Compute4x8Unit + Activation4x8: + cmp x4, #3 + beq Relu64x8 + cmp x4, #1 + beq Relu4x8 + b WriteTail + + Relu64x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b WriteTail + + LoopTailCol4x4: + cbz x20, ReloadTail4x4 + cbnz x15, InitTailFromBias4x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b ComputeTail4x4Enter + InitTailFromBias4x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + b ComputeTail4x4Enter + ReloadTail4x4: + bl Reload + ComputeTail4x4Enter: + cbz x21, Activation4x4 + bl Compute4x4Unit + Activation4x4: + cmp x4, #3 + beq Relu64x4 + cmp x4, #1 + beq Relu4x4 + b WriteTail + + Relu64x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + Relu4x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + +WriteTail: + mov x21, x14 + cmp x9, #1 + beq WriteTail1 + cmp x9, #2 + beq WriteTail2 + cmp x9, #3 + beq WriteTail3 + cmp x9, #4 + beq WriteTail4 + cmp x9, #5 + beq WriteTail5 + cmp x9, #6 + beq WriteTail6 + cmp x9, #7 + beq WriteTail7 + b WriteTail8 + + WriteTail1: + str s8, [x21] + cmp x10, #1 + beq LoopTailCol + add x21, x21, x8 + str s10, [x21] + cmp x10, #2 + beq LoopTailCol + add x21, x21, x8 + str s12, [x21] + cmp x10, #3 + beq LoopTailCol + add x21, x21, x8 + str s14, [x21] + cmp x10, #4 + beq LoopTailCol + add x21, x21, x8 + str s16, [x21] + cmp x10, #5 + beq LoopTailCol + add x21, x21, x8 + str s18, [x21] + cmp x10, #6 + beq LoopTailCol + add x21, x21, x8 + str s20, [x21] + cmp x10, #7 + beq LoopTailCol + add x21, x21, x8 + str s22, [x21] + cmp x10, #8 + beq LoopTailCol + add x21, x21, x8 + str s24, [x21] + cmp x10, #9 + beq LoopTailCol + add 
x21, x21, x8 + str s26, [x21] + cmp x10, #10 + beq LoopTailCol + add x21, x21, x8 + str s28, [x21] + cmp x10, #11 + beq LoopTailCol + add x21, x21, x8 + str s30, [x21] + b LoopTailCol + WriteTail2: + st1 {v8.2s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.2s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.2s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.2s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.2s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.2s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.2s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.2s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.2s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.2s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.2s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.2s}, [x21], x8 + add x21, x21, #8 + b LoopTailCol + WriteTail3: + add x11, x21, #8 + st1 {v8.2s}, [x21], x8 + st1 {v8.s}[2], [x11], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.2s}, [x21], x8 + st1 {v10.s}[2], [x11], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.2s}, [x21], x8 + st1 {v12.s}[2], [x11], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.2s}, [x21], x8 + st1 {v14.s}[2], [x11], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.2s}, [x21], x8 + st1 {v16.s}[2], [x11], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.2s}, [x21], x8 + st1 {v18.s}[2], [x11], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.2s}, [x21], x8 + st1 {v20.s}[2], [x11], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.2s}, [x21], x8 + st1 {v22.s}[2], [x11], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.2s}, [x21], x8 + st1 {v24.s}[2], [x11], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.2s}, [x21], x8 + st1 {v26.s}[2], [x11], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.2s}, [x21], x8 + st1 {v28.s}[2], [x11], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.2s}, [x21], x8 + st1 {v30.s}[2], [x11] + add x21, x21, #12 + b LoopTailCol + WriteTail4: + st1 {v8.4s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + add x21, x21, #16 + b LoopTailCol + WriteTail5: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + str s9, [x11] + cmp x10, #1 + beq LoopTailCol + add x11, x11, x8 + st1 {v10.4s}, [x21], x8 + str s11, [x11] + cmp x10, #2 + beq LoopTailCol + add x11, x11, x8 + st1 {v12.4s}, [x21], x8 + str s13, [x11] + cmp x10, #3 + beq LoopTailCol + add x11, x11, x8 + st1 {v14.4s}, [x21], x8 + str s15, [x11] + cmp x10, #4 + beq LoopTailCol + add x11, x11, x8 + st1 {v16.4s}, [x21], x8 + str s17, [x11] + cmp x10, #5 + beq LoopTailCol + add x11, x11, x8 + st1 {v18.4s}, [x21], x8 + str s19, [x11] + cmp x10, #6 + beq LoopTailCol + add x11, x11, x8 + st1 {v20.4s}, [x21], x8 + str s21, [x11] + cmp x10, #7 + beq LoopTailCol + add x11, x11, x8 + st1 {v22.4s}, [x21], x8 + str s23, [x11] + cmp x10, #8 + beq LoopTailCol + add x11, x11, x8 + st1 {v24.4s}, [x21], x8 + str s25, [x11] + cmp x10, #9 + beq LoopTailCol + add x11, x11, 
x8 + st1 {v26.4s}, [x21], x8 + str s27, [x11] + cmp x10, #10 + beq LoopTailCol + add x11, x11, x8 + st1 {v28.4s}, [x21], x8 + str s29, [x11] + cmp x10, #11 + beq LoopTailCol + add x11, x11, x8 + st1 {v30.4s}, [x21], x8 + str s31, [x11] + add x21, x21, #20 + b LoopTailCol + WriteTail6: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + add x21, x21, #24 + b LoopTailCol + WriteTail7: + add x11, x21, #16 + add x23, x21, #24 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x23], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x23], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x23], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x23], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x23], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x23], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x23], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x23], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x23], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x23], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x23], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + st1 {v31.s}[2], [x23] + add x21, x21, #28 + b LoopTailCol + + WriteTail8: + st1 {v8.4s, v9.4s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s, v11.4s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s, v13.4s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s, v15.4s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s, v17.4s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s, v19.4s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s, v21.4s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s, v23.4s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s, v25.4s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s, v27.4s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s, v29.4s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s, v31.4s}, [x21], x8 + add x21, x21, #32 + b LoopTailCol + +LoopTailCol: + subs x9, x9, #8 + 
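+ // x9 counts the output columns still to be written; each pass of the tail column loop retires up to 8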
ble LoopTailEnd + add x12, x12, x22 // update rhs + add x14, x14, #32 + cmp x10, #4 + ble LoopTailRow4 + cmp x10, #8 + ble LoopTailRow8 + b LoopTailRow12 + +LoopTailEnd: + add x0, x0, x26 // update lhs + add x13, x13, x24 // update out + subs x10, x10, #12 // update row + bgt TailRowStart + b End + +Reload: + mov x15, x14 + cmp x9, #1 + beq Reload1 + cmp x9, #2 + beq Reload2 + cmp x9, #3 + beq Reload3 + cmp x9, #4 + beq Reload4 + cmp x9, #5 + beq Reload5 + cmp x9, #6 + beq Reload6 + cmp x9, #7 + beq Reload7 + b Reload8 + + Reload1: + ldr s8, [x15] + cmp x10, #1 + beq ReloadEnd + add x15, x15, x8 + ldr s10, [x15] + cmp x10, #2 + beq ReloadEnd + add x15, x15, x8 + ldr s12, [x15] + cmp x10, #3 + beq ReloadEnd + add x15, x15, x8 + ldr s14, [x15] + cmp x10, #4 + beq ReloadEnd + add x15, x15, x8 + ldr s16, [x15] + cmp x10, #5 + beq ReloadEnd + add x15, x15, x8 + ldr s18, [x15] + cmp x10, #6 + beq ReloadEnd + add x15, x15, x8 + ldr s20, [x15] + cmp x10, #7 + beq ReloadEnd + add x15, x15, x8 + ldr s22, [x15] + cmp x10, #8 + beq ReloadEnd + add x15, x15, x8 + ldr s24, [x15] + cmp x10, #9 + beq ReloadEnd + add x15, x15, x8 + ldr s26, [x15] + cmp x10, #10 + beq ReloadEnd + add x15, x15, x8 + ldr s28, [x15] + cmp x10, #11 + beq ReloadEnd + add x15, x15, x8 + ldr s30, [x15] + b ReloadEnd + Reload2: + ld1 {v8.2s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.2s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.2s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.2s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.2s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.2s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.2s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.2s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.2s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.2s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.2s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.2s}, [x15], x8 + add x15, x15, #8 + b ReloadEnd + Reload3: + add x19, x15, #8 + ld1 {v8.2s}, [x15], x8 + ld1 {v8.s}[2], [x19], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.2s}, [x15], x8 + ld1 {v10.s}[2], [x19], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.2s}, [x15], x8 + ld1 {v12.s}[2], [x19], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.2s}, [x15], x8 + ld1 {v14.s}[2], [x19], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.2s}, [x15], x8 + ld1 {v16.s}[2], [x19], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.2s}, [x15], x8 + ld1 {v18.s}[2], [x19], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.2s}, [x15], x8 + ld1 {v20.s}[2], [x19], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.2s}, [x15], x8 + ld1 {v22.s}[2], [x19], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.2s}, [x15], x8 + ld1 {v24.s}[2], [x19], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.2s}, [x15], x8 + ld1 {v26.s}[2], [x19], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.2s}, [x15], x8 + ld1 {v28.s}[2], [x19], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.2s}, [x15], x8 + ld1 {v30.s}[2], [x19] + add x15, x15, #12 + b ReloadEnd + Reload4: + ld1 {v8.4s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, 
[x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + add x15, x15, #16 + b ReloadEnd + Reload5: + add x19, x15, #16 + ld1 {v8.4s}, [x15], x8 + ldr s9, [x19] + cmp x10, #1 + beq ReloadEnd + add x19, x19, x8 + ld1 {v10.4s}, [x15], x8 + ldr s11, [x19] + cmp x10, #2 + beq ReloadEnd + add x19, x19, x8 + ld1 {v12.4s}, [x15], x8 + ldr s13, [x19] + cmp x10, #3 + beq ReloadEnd + add x19, x19, x8 + ld1 {v14.4s}, [x15], x8 + ldr s15, [x19] + cmp x10, #4 + beq ReloadEnd + add x19, x19, x8 + ld1 {v16.4s}, [x15], x8 + ldr s17, [x19] + cmp x10, #5 + beq ReloadEnd + add x19, x19, x8 + ld1 {v18.4s}, [x15], x8 + ldr s19, [x19] + cmp x10, #6 + beq ReloadEnd + add x19, x19, x8 + ld1 {v20.4s}, [x15], x8 + ldr s21, [x19] + cmp x10, #7 + beq ReloadEnd + add x19, x19, x8 + ld1 {v22.4s}, [x15], x8 + ldr s23, [x19] + cmp x10, #8 + beq ReloadEnd + add x19, x19, x8 + ld1 {v24.4s}, [x15], x8 + ldr s25, [x19] + cmp x10, #9 + beq ReloadEnd + add x19, x19, x8 + ld1 {v26.4s}, [x15], x8 + ldr s27, [x19] + cmp x10, #10 + beq ReloadEnd + add x19, x19, x8 + ld1 {v28.4s}, [x15], x8 + ldr s29, [x19] + cmp x10, #11 + beq ReloadEnd + add x19, x19, x8 + ld1 {v30.4s}, [x15], x8 + ldr s31, [x19] + add x15, x15, #20 + b ReloadEnd + Reload6: + add x19, x15, #16 + ld1 {v8.4s}, [x15], x8 + ld1 {v9.2s}, [x19], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + ld1 {v11.2s}, [x19], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + ld1 {v13.2s}, [x19], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + ld1 {v15.2s}, [x19], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + ld1 {v17.2s}, [x19], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + ld1 {v19.2s}, [x19], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + ld1 {v21.2s}, [x19], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + ld1 {v23.2s}, [x19], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + ld1 {v25.2s}, [x19], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + ld1 {v27.2s}, [x19], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + ld1 {v29.2s}, [x19], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + ld1 {v31.2s}, [x19] + add x15, x15, #24 + b ReloadEnd + Reload7: + add x19, x15, #16 + add x16, x15, #24 + ld1 {v8.4s}, [x15], x8 + ld1 {v9.2s}, [x19], x8 + ld1 {v9.s}[2], [x16], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + ld1 {v11.2s}, [x19], x8 + ld1 {v11.s}[2], [x16], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + ld1 {v13.2s}, [x19], x8 + ld1 {v13.s}[2], [x16], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + ld1 {v15.2s}, [x19], x8 + ld1 {v15.s}[2], [x16], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + ld1 {v17.2s}, [x19], x8 + ld1 {v17.s}[2], [x16], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + ld1 {v19.2s}, [x19], x8 + ld1 {v19.s}[2], [x16], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + ld1 {v21.2s}, [x19], x8 + ld1 {v21.s}[2], [x16], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + ld1 {v23.2s}, [x19], x8 + ld1 {v23.s}[2], [x16], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + ld1 {v25.2s}, [x19], x8 + ld1 {v25.s}[2], [x16], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + ld1 {v27.2s}, [x19], x8 + ld1 {v27.s}[2], [x16], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + ld1 {v29.2s}, [x19], x8 + ld1 {v29.s}[2], [x16], x8 + cmp x10, #11 + beq ReloadEnd + 
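+ // all eleven row checks exhausted: reload the twelfth row's 4 + 2 + 1 floats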
ld1 {v30.4s}, [x15], x8 + ld1 {v31.2s}, [x19] + ld1 {v31.s}[2], [x16] + add x15, x15, #28 + b ReloadEnd + + Reload8: + ld1 {v8.4s, v9.4s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s, v11.4s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s, v13.4s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s, v15.4s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s, v17.4s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s, v19.4s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s, v21.4s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s, v23.4s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s, v25.4s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s, v27.4s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s, v29.4s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s, v31.4s}, [x15], x8 + add x15, x15, #32 + b ReloadEnd + +ReloadEnd: + ret + +Compute12x8Unit: + subs x21, x21, #2 + ble Compute12x8End + Compute12x8: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x21, x21, #2 + bgt Compute12x8 + + Compute12x8End: + cbnz x21, Compute12x8End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, 
v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + Compute12x8End1: + ld1 {v1.4s, v2.4s}, [x11] + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ret + +Compute12x4Unit: + subs x21, x21, #2 + ble Compute12x4End + Compute12x4: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + subs x21, x21, #2 + bgt Compute12x4 + + Compute12x4End: + cbnz x21, Compute12x4End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + Compute12x4End1: + ld1 {v1.4s, v2.4s}, [x11] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + ret + +Compute8x8Unit: + subs x21, x21, #2 + ble Compute8x8End + Compute8x8: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm 
pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x21, x21, #2 + bgt Compute8x8 + + Compute8x8End: + cbnz x21, Compute8x8End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + Compute8x8End1: + ld1 {v1.4s}, [x11] + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ret + +Compute8x4Unit: + subs x21, x21, #2 + ble Compute8x4End + Compute8x4: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + subs x21, x21, #2 + bgt Compute8x4 + + Compute8x4End: + cbnz x21, Compute8x4End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, 
v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + Compute8x4End1: + ld1 {v1.4s}, [x11] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + ret + +Compute4x8Unit: + subs x21, x21, #2 + ble Compute4x8End + Compute4x8: + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + + subs x21, x21, #2 + bgt Compute4x8 + + Compute4x8End: + cbnz x21, Compute4x8End1 + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + Compute4x8End1: + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ret + +Compute4x4Unit: + subs x21, x21, #2 + ble Compute4x4End + Compute4x4: + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + + subs x21, x21, #2 + bgt Compute4x4 + + Compute4x4End: + cbnz x21, Compute4x4End1 + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + Compute4x4End1: + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ret + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S new file mode 100644 index 0000000000000000000000000000000000000000..c03863ef4853ed9bb530207a9beb15aeed79ce4b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S @@ -0,0 +1,114 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Corner + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 // weight + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S new file mode 100644 index 0000000000000000000000000000000000000000..b828c30b9f4a66c29265be69b08e2533d72f9f18 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S 
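For reference, the three boundary kernels in this patch (ConvDw3x3Corner above, plus the Horizontal and Vertical variants that follow) compute the same per-channel partial 3x3 depthwise window and differ only in how many taps are valid. Below is a minimal C sketch, assuming the [kh][kw][channel] weight layout and element-unit input strides implied by the assembly; kh_taps/kw_taps are our parameterization of what each kernel hard-codes, and the function name is ours:

    #include <stddef.h>

    // Illustrative reference only, not the shipped implementation.
    // ConvDw3x3Corner is the kh_taps = 2, kw_taps = 2 case.
    static void ConvDw3x3PartialRef(float *dst, const float *src, const float *weight,
                                    const float *bias, int in_kh_step, int in_kw_step,
                                    int channel, size_t relu, size_t relu6,
                                    int kh_taps, int kw_taps) {
      for (int c = 0; c < channel; ++c) {
        float acc = bias[c];
        for (int kh = 0; kh < kh_taps; ++kh) {
          for (int kw = 0; kw < kw_taps; ++kw) {
            acc += src[kh * in_kh_step + kw * in_kw_step + c] *
                   weight[(kh * 3 + kw) * channel + c];  // 3 = kernel width
          }
        }
        if (relu6 != 0) acc = acc < 6.0f ? acc : 6.0f;               // fmin vs. splatted 6
        if (relu != 0 || relu6 != 0) acc = acc > 0.0f ? acc : 0.0f;  // fmax vs. 0
        dst[c] = acc;
      }
    }

The assembly walks channels four at a time (C4 blocks) and interleaves the loads for block n+1 with the fmla accumulations of block n.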
@@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Horizontal + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + ld1 {v5.4s}, [x10], x13 + add x15, x11, x4 + ld1 {v2.4s}, [x11], x5 + add x16, x12, x14 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x15], x5 + ld1 {v18.4s}, [x16], x13 + ld1 {v17.4s}, [x15], x5 + ld1 {v19.4s}, [x16], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + add x15, x11, x4 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + add x16, x12, x14 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + fmla v23.4s, v16.4s, v18.4s + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x15], x5 + fmla v23.4s, v17.4s, v19.4s + ld1 {v18.4s}, [x16], x13 + ld1 {v17.4s}, [x15], x5 + ld1 {v19.4s}, [x16], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + fmla v23.4s, v16.4s, v18.4s + fmla v23.4s, v17.4s, v19.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S new file mode 100644 index 0000000000000000000000000000000000000000..339585689d09c6ed97b0cab312695b9b222ccfb0 
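ConvDw3x3Horizontal above is the kh_taps = 3, kw_taps = 2 instance of the sketch given after the corner kernel (three rows of two taps, hence the six fmla accumulations per C4 block); ConvDw3x3Vertical later in the patch is the transposed 2x3 case. The pipelining is unchanged: the loads feeding the next channel block are interleaved with the current accumulations, and the bias vector is reloaded from x3 immediately after each store.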
--- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S @@ -0,0 +1,210 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, +// int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: relu +// w10: relu6 + +asm_function ConvDw3x3Stride1 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr w8, [sp, #128] + ldr w9, [sp, #136] + ldr w10, [sp, #144] + + mov w11, #4 + mul w15, w4, w11 // col_size * 4 + mul w16, w6, w11 // channel * 4 + mul w17, w5, w11 // row_size * 4 + mov w11, #2 + mul w14, w11, w15 // col_size * 2 * 4 + + movi v23.4s, #6 + scvtf v23.4s, v23.4s + dup v24.4s, wzr + + // Load weights + ld1 {v0.4s}, [x2], x16 + ld1 {v1.4s}, [x2], x16 + ld1 {v2.4s}, [x2], x16 + ld1 {v3.4s}, [x2], x16 + ld1 {v4.4s}, [x2], x16 + ld1 {v5.4s}, [x2], x16 + ld1 {v6.4s}, [x2], x16 + ld1 {v7.4s}, [x2], x16 + ld1 {v8.4s}, [x2], x16 + + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + ld1 {v10.4s}, [x11], x15 + ld1 {v11.4s}, [x11], x15 + ld1 {v13.4s}, [x12], x15 + ld1 {v14.4s}, [x12], x15 + ld1 {v15.4s}, [x12], x15 + ld1 {v17.4s}, [x13], x15 + ld1 {v18.4s}, [x13], x15 + ld1 {v19.4s}, [x13], x15 + + ld1 {v21.4s}, [x3] + ld1 {v22.4s}, [x3] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +WIDTH2_LOOP: + fmla v21.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11] + ld1 {v16.4s}, [x12] + fmla v22.4s, v0.4s, v10.4s + ld1 {v20.4s}, [x13] + add x1, x1, x14 + fmla v21.4s, v1.4s, v10.4s + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + fmla v22.4s, v1.4s, v11.4s + ld1 {v10.4s}, [x11], x15 + fmla v21.4s, v2.4s, v11.4s + fmla v22.4s, v2.4s, v12.4s + fmla v21.4s, v3.4s, v13.4s + ld1 {v11.4s}, [x11], x15 + fmla v22.4s, v3.4s, v14.4s + fmla v21.4s, v4.4s, v14.4s + ld1 {v13.4s}, [x12], x15 + fmla v22.4s, v4.4s, v15.4s + fmla v21.4s, v5.4s, v15.4s + ld1 {v14.4s}, [x12], x15 + fmla v22.4s, v5.4s, v16.4s + fmla v21.4s, v6.4s, v17.4s + ld1 {v15.4s}, [x12], x15 + fmla v22.4s, v6.4s, v18.4s + fmla v21.4s, v7.4s, v18.4s + ld1 {v17.4s}, [x13], x15 + fmla v22.4s, v7.4s, v19.4s + fmla v21.4s, v8.4s, v19.4s + ld1 {v18.4s}, [x13], x15 + fmla v22.4s, v8.4s, v20.4s + ld1 {v19.4s}, [x13], x15 + + cbnz x10, WIDTH2_RELU6 + cbnz x9, WIDTH2_RELU + b WIDTH2_WRITE + WIDTH2_RELU6: + fmin v21.4s, v21.4s, v23.4s + fmin v22.4s, v22.4s, v23.4s + WIDTH2_RELU: + fmax v21.4s, v21.4s, v24.4s + fmax v22.4s, v22.4s, v24.4s + WIDTH2_WRITE: + st1 {v21.4s}, [x0], x16 + ld1 {v21.4s}, [x3] + st1 {v22.4s}, [x0], x16 + ld1 {v22.4s}, [x3] 
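+ // two output pixels retired per pass; both bias vectors were just reloaded from x3 for the next pair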
+ + sub w8, w8, #2 + cmp w8, #2 + bgt WIDTH2_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + fmla v21.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11] + fmla v22.4s, v0.4s, v10.4s + fmla v21.4s, v1.4s, v10.4s + ld1 {v16.4s}, [x12] + fmla v22.4s, v1.4s, v11.4s + fmla v21.4s, v2.4s, v11.4s + ld1 {v20.4s}, [x13] + fmla v22.4s, v2.4s, v12.4s + fmla v21.4s, v3.4s, v13.4s + fmla v22.4s, v3.4s, v14.4s + fmla v21.4s, v4.4s, v14.4s + fmla v22.4s, v4.4s, v15.4s + fmla v21.4s, v5.4s, v15.4s + fmla v22.4s, v5.4s, v16.4s + fmla v21.4s, v6.4s, v17.4s + fmla v22.4s, v6.4s, v18.4s + fmla v21.4s, v7.4s, v18.4s + fmla v22.4s, v7.4s, v19.4s + fmla v21.4s, v8.4s, v19.4s + fmla v22.4s, v8.4s, v20.4s + + cbnz x10, WIDTH2_LEFT_RELU6 + cbnz x9, WIDTH2_LEFT_RELU + b WIDTH2_LEFT_WRITE + WIDTH2_LEFT_RELU6: + fmin v21.4s, v21.4s, v23.4s + fmin v22.4s, v22.4s, v23.4s + WIDTH2_LEFT_RELU: + fmax v21.4s, v21.4s, v24.4s + fmax v22.4s, v22.4s, v24.4s + WIDTH2_LEFT_WRITE: + st1 {v21.4s}, [x0], x16 + st1 {v22.4s}, [x0], x16 + b End + +WIDTH1_LEFT: + fmla v21.4s, v0.4s, v9.4s + fmla v21.4s, v1.4s, v10.4s + fmla v21.4s, v2.4s, v11.4s + fmla v21.4s, v3.4s, v13.4s + fmla v21.4s, v4.4s, v14.4s + fmla v21.4s, v5.4s, v15.4s + fmla v21.4s, v6.4s, v17.4s + fmla v21.4s, v7.4s, v18.4s + fmla v21.4s, v8.4s, v19.4s + + cbnz x10, WIDTH1_RELU6 + cbnz x9, WIDTH1_RELU + b WIDTH1_WRITE + WIDTH1_RELU6: + fmin v21.4s, v21.4s, v23.4s + WIDTH1_RELU: + fmax v21.4s, v21.4s, v24.4s + WIDTH1_WRITE: + st1 {v21.4s}, [x0] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S new file mode 100644 index 0000000000000000000000000000000000000000..b3d90e05f4e27a79065de090f6b992baf77d45f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S @@ -0,0 +1,212 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, +// int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: relu +// w10: relu6 + +asm_function ConvDw3x3Stride2 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr w8, [sp, #128] + ldr w9, [sp, #136] + ldr w10, [sp, #144] + + mov w11, #4 + mul w15, w4, w11 // col_size * 4 + mul w16, w6, w11 // channel * 4 + mul w17, w5, w11 // row_size * 4 + mov w11, #2 + mul w14, w11, w15 // col_size * 2 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + // Load weights + ld1 {v0.4s}, [x2], x16 + ld1 {v1.4s}, [x2], x16 + ld1 {v2.4s}, [x2], x16 + ld1 {v3.4s}, [x2], x16 + ld1 {v4.4s}, [x2], x16 + ld1 {v5.4s}, [x2], x16 + ld1 {v6.4s}, [x2], x16 + ld1 {v7.4s}, [x2], x16 + ld1 {v8.4s}, [x2], x16 + + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + ld1 {v10.4s}, [x11], x15 + ld1 {v11.4s}, [x11], x15 + ld1 {v14.4s}, [x12], x15 + ld1 {v15.4s}, [x12], x15 + ld1 {v16.4s}, [x12], x15 + ld1 {v19.4s}, [x13], x15 + ld1 {v20.4s}, [x13], x15 + ld1 {v21.4s}, [x13], x15 + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x3] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +WIDTH2_LOOP: + fmla v24.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11], x15 + fmla v25.4s, v0.4s, v11.4s + ld1 {v17.4s}, [x12], x15 + fmla v24.4s, v1.4s, v10.4s + ld1 {v22.4s}, [x13], x15 + fmla v25.4s, v1.4s, v12.4s + ld1 {v13.4s}, [x11], x15 + fmla v24.4s, v2.4s, v11.4s + ld1 {v18.4s}, [x12], x15 + fmla v25.4s, v2.4s, v13.4s + ld1 {v23.4s}, [x13], x15 + fmla v24.4s, v3.4s, v14.4s + mov v9.16b, v13.16b + fmla v25.4s, v3.4s, v16.4s + mov v14.16b, v18.16b + fmla v24.4s, v4.4s, v15.4s + fmla v25.4s, v4.4s, v17.4s + ld1 {v10.4s}, [x11], x15 + fmla v24.4s, v5.4s, v16.4s + ld1 {v11.4s}, [x11], x15 + fmla v25.4s, v5.4s, v18.4s + ld1 {v15.4s}, [x12], x15 + fmla v24.4s, v6.4s, v19.4s + ld1 {v16.4s}, [x12], x15 + fmla v25.4s, v6.4s, v21.4s + mov v19.16b, v23.16b + fmla v24.4s, v7.4s, v20.4s + fmla v25.4s, v7.4s, v22.4s + ld1 {v20.4s}, [x13], x15 + fmla v24.4s, v8.4s, v21.4s + fmla v25.4s, v8.4s, v23.4s + ld1 {v21.4s}, [x13], x15 + + cbnz x10, WIDTH2_RELU6 + cbnz x9, WIDTH2_RELU + b WIDTH2_WRITE + WIDTH2_RELU6: + fmin v24.4s, v24.4s, v26.4s + fmin v25.4s, v25.4s, v26.4s + WIDTH2_RELU: + fmax v24.4s, v24.4s, v27.4s + fmax v25.4s, v25.4s, v27.4s + WIDTH2_WRITE: + st1 {v24.4s}, [x0], x16 + ld1 {v24.4s}, [x3] + st1 {v25.4s}, [x0], x16 + ld1 {v25.4s}, [x3] + + sub w8, w8, #2 + cmp w8, #2 + bgt WIDTH2_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + fmla v24.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11], x15 + fmla v25.4s, v0.4s, v11.4s + ld1 {v17.4s}, [x12], x15 + fmla v24.4s, v1.4s, v10.4s + ld1 {v22.4s}, [x13], x15 + fmla v25.4s, v1.4s, v12.4s + ld1 {v13.4s}, [x11], x15 + fmla v24.4s, v2.4s, v11.4s + ld1 {v18.4s}, [x12], x15 + fmla v25.4s, v2.4s, v13.4s + ld1 {v23.4s}, [x13], x15 + fmla v24.4s, v3.4s, v14.4s + fmla v25.4s, v3.4s, v16.4s + fmla v24.4s, v4.4s, v15.4s + fmla v25.4s, v4.4s, v17.4s + fmla v24.4s, v5.4s, v16.4s + fmla v25.4s, v5.4s, v18.4s + fmla v24.4s, v6.4s, v19.4s + fmla v25.4s, v6.4s, v21.4s + fmla v24.4s, v7.4s, v20.4s + fmla v25.4s, v7.4s, v22.4s + 
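+ // v8 holds the ninth (bottom-right) weight of the 3x3 window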
fmla v24.4s, v8.4s, v21.4s + fmla v25.4s, v8.4s, v23.4s + + cbnz x10, WIDTH2_LEFT_RELU6 + cbnz x9, WIDTH2_LEFT_RELU + b WIDTH2_LEFT_WRITE + WIDTH2_LEFT_RELU6: + fmin v24.4s, v24.4s, v26.4s + fmin v25.4s, v25.4s, v26.4s + WIDTH2_LEFT_RELU: + fmax v24.4s, v24.4s, v27.4s + fmax v25.4s, v25.4s, v27.4s + WIDTH2_LEFT_WRITE: + st1 {v24.4s}, [x0], x16 + st1 {v25.4s}, [x0], x16 + b End + +WIDTH1_LEFT: + fmla v24.4s, v0.4s, v9.4s + fmla v24.4s, v1.4s, v10.4s + fmla v24.4s, v2.4s, v11.4s + fmla v24.4s, v3.4s, v14.4s + fmla v24.4s, v4.4s, v15.4s + fmla v24.4s, v5.4s, v16.4s + fmla v24.4s, v6.4s, v19.4s + fmla v24.4s, v7.4s, v20.4s + fmla v24.4s, v8.4s, v21.4s + + cbnz x10, WIDTH1_RELU6 + cbnz x9, WIDTH1_RELU + b WIDTH1_WRITE + WIDTH1_RELU6: + fmin v24.4s, v24.4s, v26.4s + WIDTH1_RELU: + fmax v24.4s, v24.4s, v27.4s + WIDTH1_WRITE: + st1 {v24.4s}, [x0] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S new file mode 100644 index 0000000000000000000000000000000000000000..dbacdb46bba03ab6f7f2c12a365ff3e99da53801 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Vertical + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 + ld1 {v1.4s}, [x9], x5 + add x12, x2, x14 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x9], x5 + ld1 {v18.4s}, [x10], x13 + ld1 {v17.4s}, [x11], x5 + ld1 {v19.4s}, [x12], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + fmla v23.4s, v16.4s, v18.4s + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x9], x5 + fmla v23.4s, v17.4s, v19.4s + ld1 {v18.4s}, [x10], x13 + ld1 {v17.4s}, [x11], x5 + ld1 {v19.4s}, [x12], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + fmla v23.4s, v16.4s, v18.4s + fmla v23.4s, v17.4s, v19.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S new file mode 100644 index 0000000000000000000000000000000000000000..9c5237e6d36e54dcab734c791e924c4970f0dce9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S @@ -0,0 +1,500 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max, +// size_t per_channel) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max +// w16: per_channel + +asm_function ConvDw3x3Int8Neon64 + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + ldr x23, [sp, #256] // per_channel + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + add w21, w4, w4 // col_size * 2 + dup v25.8b, w9 + + cbnz w23, PER_CHANNEL_DUMP + PER_LAYER_DUMP: + ld1r {v27.4s}, [x11] // out_multiplier + ld1r {v26.4s}, [x12] // left_shift + ld1r {v28.4s}, [x13] // right_shift + b MAIN_FUC + PER_CHANNEL_DUMP: + ld1 {v27.4s}, [x11] + ld1 {v26.4s}, [x12] + ld1 {v28.4s}, [x13] + MAIN_FUC: + dup v29.4s, w10 + dup v30.4s, w14 + dup v31.4s, w15 + ldr w24, [x12] + + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ld1 {v11.8b}, [x16], x4 + ld1 {v13.8b}, [x17], x4 + ld1 {v14.8b}, [x17], x4 + ld1 {v15.8b}, [x17], x4 + ld1 {v17.8b}, [x25], x4 + ld1 {v18.8b}, [x25], x4 + ld1 {v19.8b}, [x25], x4 + + ld1 {v21.4s}, [x3] + ld1 {v22.4s}, [x19] + ld1 {v23.4s}, [x3] + ld1 {v24.4s}, [x19] + + // subtract input zp + ssubl v9.8h, v9.8b, v25.8b + ssubl v10.8h, v10.8b, v25.8b + ssubl v11.8h, v11.8b, v25.8b + ssubl v13.8h, v13.8b, v25.8b + ssubl v14.8h, v14.8b, v25.8b + ssubl v15.8h, v15.8b, v25.8b + ssubl v17.8h, v17.8b, v25.8b + ssubl v18.8h, v18.8b, v25.8b + ssubl v19.8h, v19.8b, v25.8b + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v21.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16] + smlal2 v22.4s, v0.8h, v9.8h + ld1 {v16.8b}, [x17] + smlal v23.4s, v0.4h, v10.4h + smlal2 v24.4s, v0.8h, v10.8h + ld1 {v20.8b}, [x25] + add x1, x1, x21 + ssubl v12.8h, v12.8b, v25.8b + smlal v21.4s, v1.4h, v10.4h + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + smlal2 v22.4s, v1.8h, v10.8h + ld1 {v9.8b}, [x16], x4 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v1.4h, v11.4h + ld1 {v10.8b}, [x16], x4 + ssubl v20.8h, v20.8b, v25.8b + smlal2 v24.4s, v1.8h, v11.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + smlal v23.4s, v2.4h, v12.4h + smlal2 v24.4s, v2.8h, v12.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + ld1 {v13.8b}, [x17], x4 + smlal v23.4s, v3.4h, v14.4h + smlal2 v24.4s, v3.8h, v14.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + ld1 {v14.8b}, 
[x17], x4 + smlal v23.4s, v4.4h, v15.4h + smlal2 v24.4s, v4.8h, v15.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v23.4s, v5.4h, v16.4h + smlal2 v24.4s, v5.8h, v16.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + ld1 {v17.8b}, [x25], x4 + smlal v23.4s, v6.4h, v18.4h + smlal2 v24.4s, v6.8h, v18.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + ld1 {v18.8b}, [x25], x4 + smlal v23.4s, v7.4h, v19.4h + smlal2 v24.4s, v7.8h, v19.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + ld1 {v19.8b}, [x25], x4 + smlal v23.4s, v8.4h, v20.4h + smlal2 v24.4s, v8.8h, v20.8h + + cbnz w23, PER_CHANNEL_POST1 + cbz w24, SKIP_LEFTSHIFT1 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + + and v12.16b, v21.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v21.4s, v21.4s, v12.4s + sqrshl v21.4s, v21.4s, v28.4s + + and v12.16b, v22.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v22.4s, v22.4s, v12.4s + sqrshl v22.4s, v22.4s, v28.4s + + and v12.16b, v23.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v23.4s, v23.4s, v12.4s + sqrshl v23.4s, v23.4s, v28.4s + + and v12.16b, v24.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v24.4s, v24.4s, v12.4s + sqrshl v24.4s, v24.4s, v28.4s + b OUTZP1 + +PER_CHANNEL_POST1: + sqshl v21.4s, v21.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + ldr q26, [x12, #16] + + and v12.16b, v21.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v21.4s, v21.4s, v12.4s + sqrshl v21.4s, v21.4s, v28.4s + + and v12.16b, v23.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v23.4s, v23.4s, v12.4s + sqrshl v23.4s, v23.4s, v28.4s + + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v26.4s}, [x12] + + and v12.16b, v22.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v22.4s, v22.4s, v12.4s + sqrshl v22.4s, v22.4s, v28.4s + + and v12.16b, v24.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v24.4s, v24.4s, v12.4s + sqrshl v24.4s, v24.4s, v28.4s + + ld1 {v27.4s}, [x11] + ld1 {v28.4s}, [x13] + +OUTZP1: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + sqadd v23.4s, v23.4s, v29.4s + sqadd v24.4s, v24.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + ld1 {v22.4s}, [x19] + ssubl v9.8h, v9.8b, v25.8b + ssubl v10.8h, v10.8b, v25.8b + sqxtn v23.4h, v23.4s + sqxtn2 v23.8h, v24.4s + ld1 {v24.4s}, [x19] + sqxtn v21.8b, v21.8h + sqxtn2 v21.16b, v23.8h + st1 {v21.8b}, [x0], x6 + mov v23.d[0], v21.d[1] + ld1 {v21.4s}, [x3] + st1 {v23.8b}, [x0], x6 + ssubl v11.8h, v11.8b, v25.8b + ssubl v13.8h, v13.8b, v25.8b + ld1 {v23.4s}, [x3] + ssubl v14.8h, v14.8b, v25.8b + ssubl v15.8h, v15.8b, v25.8b + ssubl v17.8h, v17.8b, v25.8b + ssubl v18.8h, 
v18.8b, v25.8b + ssubl v19.8h, v19.8b, v25.8b + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v21.4s, v0.4h, v9.4h + smlal2 v22.4s, v0.8h, v9.8h + ld1 {v12.8b}, [x16] + ssubl v12.8h, v12.8b, v25.8b + smlal v23.4s, v0.4h, v10.4h + smlal2 v24.4s, v0.8h, v10.8h + smlal v21.4s, v1.4h, v10.4h + smlal2 v22.4s, v1.8h, v10.8h + ld1 {v16.8b}, [x17] + smlal v23.4s, v1.4h, v11.4h + smlal2 v24.4s, v1.8h, v11.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + ld1 {v20.8b}, [x25] + smlal v23.4s, v2.4h, v12.4h + smlal2 v24.4s, v2.8h, v12.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + smlal v23.4s, v3.4h, v14.4h + smlal2 v24.4s, v3.8h, v14.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v4.4h, v15.4h + smlal2 v24.4s, v4.8h, v15.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + ssubl v20.8h, v20.8b, v25.8b + smlal v23.4s, v5.4h, v16.4h + smlal2 v24.4s, v5.8h, v16.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + smlal v23.4s, v6.4h, v18.4h + smlal2 v24.4s, v6.8h, v18.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + smlal v23.4s, v7.4h, v19.4h + smlal2 v24.4s, v7.8h, v19.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + smlal v23.4s, v8.4h, v20.4h + smlal2 v24.4s, v8.8h, v20.8h + + cbnz w23, PER_CHANNEL_POST2 + cbz w24, SKIP_LEFTSHIFT2 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v22.4s, v22.4s, v28.4s + sqrshl v23.4s, v23.4s, v28.4s + sqrshl v24.4s, v24.4s, v28.4s + b OUTZP2 + +PER_CHANNEL_POST2: + sqshl v21.4s, v21.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + ldr q26, [x12, #16] + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v23.4s, v23.4s, v28.4s + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v22.4s, v22.4s, v28.4s + sqrshl v24.4s, v24.4s, v28.4s + +OUTZP2: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + sqadd v23.4s, v23.4s, v29.4s + sqadd v24.4s, v24.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + sqxtn v23.4h, v23.4s + sqxtn2 v23.8h, v24.4s + sqxtn v21.8b, v21.8h + sqxtn2 v21.16b, v23.8h + st1 {v21.8b}, [x0], x6 + mov v23.d[0], v21.d[1] + st1 {v23.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v21.4s, v0.4h, v9.4h + smlal2 v22.4s, v0.8h, v9.8h + smlal v21.4s, v1.4h, v10.4h + smlal2 v22.4s, v1.8h, v10.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h 
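+ // third kernel row: weights v6..v8 against input rows v17..v19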
+ smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + + cbnz w23, PER_CHANNEL_POST3 + cbz w24, SKIP_LEFTSHIFT3 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v22.4s, v22.4s, v28.4s + b OUTZP3 + +PER_CHANNEL_POST3: + sqshl v21.4s, v21.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + ldr q26, [x12, #16] + sqrshl v21.4s, v21.4s, v28.4s + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrshl v22.4s, v22.4s, v28.4s + +OUTZP3: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + sqxtn v21.8b, v21.8h + st1 {v21.8b}, [x0], x6 + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S new file mode 100644 index 0000000000000000000000000000000000000000..3af1f98520ebf41134b868b8278a410a18569a6f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S @@ -0,0 +1,222 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
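+ */
+
+/*
+ * Computes one corner output pixel of the 3x3 int8 depthwise convolution:
+ * only the 2x2 taps that overlap the input are applied, 8 channels per
+ * iteration. A rough per-lane C sketch of the fixed-point requantization
+ * performed by the sqshl/sqrdmulh/sqrshl sequence below (sat_* are
+ * illustrative pseudo-helpers; names follow the prototype comment below):
+ *
+ *   int32_t acc = bias[c];
+ *   for (int k = 0; k < 4; ++k) acc += (src[k][c] - in_zp) * weight[k][c];
+ *   acc = sat_shl(acc, left_shift[c]);                          // sqshl
+ *   acc = sat_round_doubling_mul_high(acc, out_multiplier[c]);  // sqrdmulh
+ *   acc = sat_round_shl(acc, right_shift[c]);  // sqrshl; negative = right shift
+ *   dst[c] = clamp(acc + out_zp, acc_min, acc_max);  // add, smax, smin, sqxtn
+ *
+ * The per-layer path (per_channel == 0) broadcasts one multiplier/shift pair
+ * and applies either the left-shift or the rounding-right-shift variant.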
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Corner + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #32] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #40] // out_multiplier + ldr x10, [sp, #48] // left_shift + ldr x11, [sp, #56] // right_shift + ldr x12, [sp, #64] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #72] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #80] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + + mov x12, #2 + mul x21, x6, x12 // x6 * 2 + mov x12, #3 + mul x22, x21, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x21 // weight + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x21 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + ld1 {v6.8h}, [x20], x21 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x21 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x21 + smlal2 v24.4s, v2.8h, v6.8h + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x21 + smlal2 v24.4s, v3.8h, v7.8h + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + 
sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S new file mode 100644 index 0000000000000000000000000000000000000000..88b2e2beb1e9e8e57b9d997d479a4996894cee95 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S @@ -0,0 +1,255 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
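+ */
+
+/*
+ * Computes one border output pixel of the 3x3 int8 depthwise convolution,
+ * applying the 3x2 tap subset (three kernel rows, two kernel columns) that
+ * overlaps the input, 8 channels per step; requantization follows the same
+ * per-layer / per-channel scheme as ConvDw3x3Int8Corner.S.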
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Horizontal + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #48] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #56] // out_multiplier + ldr x10, [sp, #64] // left_shift + ldr x11, [sp, #72] // right_shift + ldr x12, [sp, #80] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #88] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #96] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + ldr x12, [sp, #80] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #88] + dup v31.4s, w13 // acc_max + + mov x12, #2 + mul x23, x6, x12 // x6 * 2 + mov x12, #3 + mul x24, x23, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x23 // weight + add x20, x2, x24 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x23 + add x21, x19, x4 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + add x22, x20, x24 + ld1 {v6.8h}, [x20], x23 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x23 + ld1 {v16.8b}, [x21], x5 + ssubl v16.8h, v16.8b, v25.8b + ld1 {v18.8h}, [x22], x23 + ld1 {v17.8b}, [x21], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x22], x23 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x23 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + + add x20, x2, x24 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x23 + smlal2 v24.4s, v2.8h, v6.8h + + add x21, x19, x4 + add x22, x20, x24 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x23 + smlal2 v24.4s, v3.8h, v7.8h + + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + smlal v23.4s, v16.4h, v18.4h + ld1 {v7.8h}, [x20], x23 + smlal2 v24.4s, v16.8h, v18.8h + + ld1 {v16.8b}, [x21], x5 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v17.4h, v19.4h + ld1 {v18.8h}, [x22], x23 + smlal2 v24.4s, v17.8h, v19.8h + ld1 {v17.8b}, [x21], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x22], x23 + + cbnz x14, 
PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + smlal v23.4s, v16.4h, v18.4h + smlal2 v24.4s, v16.8h, v18.8h + smlal v23.4s, v17.4h, v19.4h + smlal2 v24.4s, v17.8h, v19.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S new file mode 100644 index 0000000000000000000000000000000000000000..0209dfe76be193cbe4d8d2d4acb20c583f9d0e68 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S @@ -0,0 +1,474 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max +// size_t per_channel) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max +// w16: per_channel + +asm_function ConvDw3x3Int8Stride2 + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + ldr x23, [sp, #256] // per_channel + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + dup v28.8b, w9 // in_zp + ldr w24, [x12] + + dup v29.4s, w10 + dup v30.4s, w14 + dup v31.4s, w15 + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ssubl v9.8h, v9.8b, v28.8b + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + ld1 {v14.8b}, [x17], x4 + ssubl v11.8h, v11.8b, v28.8b + ld1 {v15.8b}, [x17], x4 + ssubl v14.8h, v14.8b, v28.8b + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + ld1 {v19.8b}, [x25], x4 + ssubl v16.8h, v16.8b, v28.8b + ld1 {v20.8b}, [x25], x4 + ssubl v19.8h, v19.8b, v28.8b + ld1 {v21.8b}, [x25], x4 + ssubl v20.8h, v20.8b, v28.8b + ssubl v21.8h, v21.8b, v28.8b + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x19] + ld1 {v26.4s}, [x3] + ld1 {v27.4s}, [x19] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x25], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x25], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + mov v9.16b, v13.16b + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + ld1 {v10.8b}, [x16], x4 + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + smlal v26.4s, v2.4h, v13.4h + ssubl 
v11.8h, v11.8b, v28.8b + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + mov v14.16b, v18.16b + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + smlal v26.4s, v5.4h, v18.4h + ssubl v16.8h, v16.8b, v28.8b + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + mov v19.16b, v23.16b + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + ld1 {v20.8b}, [x25], x4 + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + ld1 {v21.8b}, [x25], x4 + ssubl v20.8h, v20.8b, v28.8b + smlal v26.4s, v8.4h, v23.4h + ssubl v21.8h, v21.8b, v28.8b + smlal2 v27.4s, v8.8h, v23.8h + + cbnz w23, PER_CHANNEL_POST1 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT1 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v22.4s + b OUTZP1 + +PER_CHANNEL_POST1: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v13.4s + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v18.4s + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v23.4s + +OUTZP1: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + sqadd v26.4s, v26.4s, v29.4s + sqadd v27.4s, v27.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + smax v26.4s, v26.4s, v30.4s + smax v27.4s, v27.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + smin v26.4s, v26.4s, v31.4s + smin v27.4s, v27.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + ld1 {v25.4s}, [x19] + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + ld1 {v27.4s}, [x19] + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + ld1 {v24.4s}, [x3] + st1 {v26.8b}, [x0], x6 + ld1 {v26.4s}, [x3] + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x25], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b 
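+  // ssubl widens the freshly loaded int8 inputs to int16 while subtracting
+  // the input zero point (v28), so smlal/smlal2 can accumulate exactly in int32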
+ smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x25], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v26.4s, v2.4h, v13.4h + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v26.4s, v5.4h, v18.4h + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + smlal v26.4s, v8.4h, v23.4h + smlal2 v27.4s, v8.8h, v23.8h + + cbnz w23, PER_CHANNEL_POST2 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT2 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v22.4s + b OUTZP2 + +PER_CHANNEL_POST2: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v13.4s + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v18.4s + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v23.4s + +OUTZP2: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + sqadd v26.4s, v26.4s, v29.4s + sqadd v27.4s, v27.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + smax v26.4s, v26.4s, v30.4s + smax v27.4s, v27.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + smin v26.4s, v26.4s, v31.4s + smin v27.4s, v27.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + st1 {v26.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v24.4s, v0.4h, v9.4h + smlal2 v25.4s, v0.8h, v9.8h + smlal v24.4s, v1.4h, v10.4h + smlal2 v25.4s, v1.8h, v10.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + + cbnz w23, PER_CHANNEL_POST3 + ld1r 
{v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT3 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + b OUTZP3 + +PER_CHANNEL_POST3: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + +OUTZP3: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v24.8b, v24.8h + st1 {v24.8b}, [x0], x6 + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S new file mode 100644 index 0000000000000000000000000000000000000000..a0c2ca54def5653f60696891e393d1355333ca66 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S @@ -0,0 +1,245 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
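+ */
+
+/*
+ * Computes one border output pixel of the 3x3 int8 depthwise convolution,
+ * applying the 2x3 tap subset (two kernel rows, three kernel columns) that
+ * overlaps the input, 8 channels per step; post-processing mirrors
+ * ConvDw3x3Int8Corner.S.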
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Vertical + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #32] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #40] // out_multiplier + ldr x10, [sp, #48] // left_shift + ldr x11, [sp, #56] // right_shift + ldr x12, [sp, #64] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #72] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #80] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + + mov x12, #2 + mul x21, x6, x12 // x6 * 2 + mov x12, #3 + mul x22, x21, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x21 // weight + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x21 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + ld1 {v6.8h}, [x20], x21 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + ld1 {v16.8b}, [x12], x5 + ssubl v16.8h, v16.8b, v25.8b + ld1 {v18.8h}, [x13], x21 + ld1 {v17.8b}, [x19], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x20], x21 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x21 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x21 + smlal2 v24.4s, v2.8h, v6.8h + + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x21 + smlal2 v24.4s, v3.8h, v7.8h + + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + smlal v23.4s, v16.4h, v18.4h + ld1 {v7.8h}, [x20], x21 + smlal2 v24.4s, v16.8h, v18.8h + + ld1 {v16.8b}, [x12], x5 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v17.4h, v19.4h + ld1 {v18.8h}, [x13], x21 + smlal2 v24.4s, v17.8h, v19.8h + ld1 {v17.8b}, [x19], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x20], x21 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + 
RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + smlal v23.4s, v16.4h, v18.4h + smlal2 v24.4s, v16.8h, v18.8h + smlal v23.4s, v17.4h, v19.4h + smlal2 v24.4s, v17.8h, v19.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S new file mode 100644 index 0000000000000000000000000000000000000000..3a0f8af61ff28bd4d9fa6f24cfb56e399944402b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S @@ -0,0 +1,203 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
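+ */
+
+/*
+ * Fp32 3x3 depthwise line kernel: consumes three line buffers (lines[0..2])
+ * and emits two output columns per LoopOw2 iteration with an F(2,3)-style
+ * output combine over the four per-column accumulators m0..m3 (v24..v27):
+ *
+ *   out0 = m0 + m1 + m2
+ *   out1 = m1 - m2 + m3
+ *
+ * Bias and relu/relu6 are fused before the 4/3/2/1-channel writeback.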
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, +// bool relu, bool relu6) + +// x0: dst, x1: lines, x2: weight, x3: bias, x4: width, x5: ori_channel, x6: relu, x7: relu6 +asm_function ConvDw3x3Line + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr x8, [x1] + ldr x9, [x1, #8] + ldr x10, [x1, #16] + mov x11, x5 + mov x16, #4 + mul x16, x5, x16 + + mov w14, #6 + dup v30.4s, w14 + scvtf v30.4s, v30.4s + + LoopC4: + cbz x3, NoBias + ld1 {v31.4s}, [x3], #16 + NoBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + mov x12, x0 + mov x13, x4 + + cmp x13, #2 + blt LoopOwRemain + LoopOw2: + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x8], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x9], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 + fmul v24.4s, v12.4s, v0.4s + fmul v25.4s, v13.4s, v1.4s + fmul v26.4s, v14.4s, v2.4s + fmul v27.4s, v15.4s, v3.4s + fmla v24.4s, v16.4s, v4.4s + fmla v25.4s, v17.4s, v5.4s + fmla v26.4s, v18.4s, v6.4s + fmla v27.4s, v19.4s, v7.4s + fmla v24.4s, v20.4s, v8.4s + fmla v25.4s, v21.4s, v9.4s + fmla v26.4s, v22.4s, v10.4s + fmla v27.4s, v23.4s, v11.4s + + fadd v28.4s, v25.4s, v26.4s + fadd v28.4s, v28.4s, v24.4s + fsub v29.4s, v27.4s, v26.4s + fadd v29.4s, v29.4s, v25.4s + + cbz x3, Activation + Bias: + fadd v28.4s, v28.4s, v31.4s + fadd v29.4s, v29.4s, v31.4s + + Activation: + cbnz x7, Relu6 + cbnz x6, Relu + b Write + Relu6: + fmin v28.4s, v28.4s, v30.4s + fmin v29.4s, v29.4s, v30.4s + Relu: + movi v27.16b, #0 + fmax v28.4s, v28.4s, v27.4s + fmax v29.4s, v29.4s, v27.4s + Write: + add x15, x12, x16 + cmp x11, #4 + bge Write4 + cmp x11, #3 + beq Write3 + cmp x11, #2 + beq Write2 + cmp x11, #1 + beq Write1 + + Write1: + str s28, [x12] + str s29, [x15] + b WriteEnd + Write2: + st1 {v28.2s}, [x12] + st1 {v29.2s}, [x15] + b WriteEnd + Write3: + st1 {v28.2s}, [x12] + add x17, x12, #8 + st1 {v28.s}[2], [x17] + st1 {v29.2s}, [x15] + add x18, x15, #8 + st1 {v29.s}[2], [x18] + b WriteEnd + Write4: + st1 {v28.4s}, [x12] + st1 {v29.4s}, [x15] + + WriteEnd: + add x12, x15, x16 + sub x13, x13, #2 + cmp x13, #2 + bge LoopOw2 + cmp x13, #0 + beq LoopOwEnd + + LoopOwRemain: + ld1 {v12.4s, v13.4s, v14.4s}, [x8] + add x8, x8, #64 + ld1 {v16.4s, v17.4s, v18.4s}, [x9] + add x9, x9, #64 + ld1 {v20.4s, v21.4s, v22.4s}, [x10] + add x10, x10, #64 + fmul v24.4s, v12.4s, v0.4s + fmul v25.4s, v13.4s, v1.4s + fmul v26.4s, v14.4s, v2.4s + + fmla v24.4s, v16.4s, v4.4s + fmla v25.4s, v17.4s, v5.4s + fmla v26.4s, v18.4s, v6.4s + + fmla v24.4s, v20.4s, v8.4s + fmla v25.4s, v21.4s, v9.4s + fmla v26.4s, v22.4s, v10.4s + + fadd v28.4s, v25.4s, v26.4s + fadd v28.4s, v28.4s, v24.4s + + cbz x3, ActivationRemain + BiasRemain: + fadd v28.4s, v28.4s, v31.4s + + ActivationRemain: + cbnz x7, Relu6Remain + cbnz x6, ReluRemain + b WriteRemain + Relu6Remain: + fmin v28.4s, v28.4s, v30.4s + ReluRemain: + movi v27.16b, #0 + fmax v28.4s, v28.4s, v27.4s + WriteRemain: + cmp x11, #4 + 
bge Write4Remain + cmp x11, #3 + beq Write3Remain + cmp x11, #2 + beq Write2Remain + cmp x11, #1 + beq Write1Remain + + Write1Remain: + str s28, [x12] + b LoopOwEnd + Write2Remain: + st1 {v28.2s}, [x12] + b LoopOwEnd + Write3Remain: + st1 {v28.2s}, [x12] + add x17, x12, #8 + st1 {v28.s}[2], [x17] + b LoopOwEnd + Write4Remain: + st1 {v28.4s}, [x12] + + LoopOwEnd: + subs x11, x11, #4 + add x0, x0, #16 + bgt LoopC4 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S new file mode 100644 index 0000000000000000000000000000000000000000..5f4744dceef216a72229371e20419c3b5edc080d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, +// x8: kernel_w, x9: relu, x10: relu6 +asm_function ConvDwFp32Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + + ld1 {v0.4s}, [x3] // bias + movi v1.4s, #6 // relu 6 + scvtf v1.4s, v1.4s + dup v2.4s, wzr // relu + + mov x13, x1 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x5 + LoopW: + ld1 {v3.4s}, [x15], x7 + ld1 {v4.4s}, [x16], #16 + fmla v0.4s, v3.4s, v4.4s + subs x17, x17, #1 + bne LoopW + subs x4, x4, #1 + add x13, x13, x6 + add x14, x14, x8 + bne LoopH + cbnz x10, Relu6 + cbnz x9, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v1.4s + Relu: + fmax v0.4s, v0.4s, v2.4s + Write: + st1 {v0.4s}, [x0] + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..568c1a33b4fa6d9c3322e788c6deb67945e2e0d5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S @@ -0,0 +1,313 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +asm_function ConvDwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + + ld1 {v24.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x25, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v8.4s, v16.4s, v25.4s + fmla v9.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v10.4s, v18.4s, v25.4s + fmla v11.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v12.4s, v20.4s, v25.4s + fmla v13.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v14.4s, v22.4s, v25.4s + fmla v15.4s, v23.4s, v25.4s + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz 
x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x25, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v16.4s}, [x22], x13 + ld1 {v25.4s}, [x17], #16 + fmla v0.4s, v16.4s, v25.4s + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz 
x14, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S new file mode 100644 index 0000000000000000000000000000000000000000..aafde32194aa2491a920a196e1963d2625d76a44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S @@ -0,0 +1,159 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, +// size_t input_stride, size_t relu, size_t relu6) +// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6 + +asm_function ConvDwFp32Indirect3x3 + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + movi v31.4s, #6 + scvtf v31.4s, v31.4s + dup v30.4s, wzr + + ldr x8, [sp, #32] + cmp x5, #0 + beq End + + LoopPixel: + ldp x12, x13, [x1] + ldp x14, x15, [x1, #16] + ldp x16, x17, [x1, #32] + ldp x21, x19, [x1, #48] + ldr x20, [x1, #64] + mov x9, x2 + mov x10, x3 + mov x11, x4 + + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x13], #16 + ld1 {v2.4s}, [x14], #16 + + ld1 {v17.4s}, [x9], #16 + ld1 {v18.4s}, [x9], #16 + ld1 {v19.4s}, [x9], #16 + + ld1 {v29.4s}, [x10], #16 + cmp x11, #4 + ble LeftLoop + LoopC4: + fmla v29.4s, v0.4s, v17.4s + ld1 {v3.4s}, [x15], #16 + ld1 {v20.4s}, [x9], #16 + fmla v29.4s, v1.4s, v18.4s + ld1 {v4.4s}, [x16], #16 + ld1 {v21.4s}, [x9], #16 + fmla v29.4s, v2.4s, v19.4s + ld1 {v5.4s}, [x17], #16 + ld1 {v22.4s}, [x9], #16 + fmla v29.4s, v3.4s, v20.4s + ld1 {v6.4s}, [x21], #16 + ld1 {v23.4s}, [x9], #16 + fmla v29.4s, v4.4s, v21.4s + ld1 {v7.4s}, [x19], #16 + ld1 {v24.4s}, [x9], #16 + fmla v29.4s, v5.4s, v22.4s + ld1 {v16.4s}, [x20], #16 + ld1 {v25.4s}, [x9], #16 + fmla v29.4s, v6.4s, v23.4s + ld1 {v0.4s}, [x12], #16 + ld1 {v17.4s}, [x9], #16 + fmla v29.4s, v7.4s, v24.4s + ld1 {v1.4s}, [x13], #16 + ld1 {v18.4s}, [x9], #16 + fmla v29.4s, v16.4s, v25.4s + ld1 {v2.4s}, [x14], #16 + ld1 {v19.4s}, [x9], #16 + + cbnz x8, Relu6 + cbnz x7, Relu + b Write + Relu6: + fmin v29.4s, v29.4s, v31.4s + Relu: + fmax v29.4s, v29.4s, v30.4s + Write: + st1 {v29.4s}, [x0], #16 + + ld1 {v29.4s}, [x10], #16 + sub x11, x11, #4 + cmp x11, #4 + bgt LoopC4 + + LeftLoop: + fmla v29.4s, v0.4s, v17.4s + ld1 {v3.4s}, [x15], #16 + ld1 
{v20.4s}, [x9], #16 + fmla v29.4s, v1.4s, v18.4s + ld1 {v4.4s}, [x16], #16 + ld1 {v21.4s}, [x9], #16 + fmla v29.4s, v2.4s, v19.4s + ld1 {v5.4s}, [x17], #16 + ld1 {v22.4s}, [x9], #16 + fmla v29.4s, v3.4s, v20.4s + ld1 {v6.4s}, [x21], #16 + ld1 {v23.4s}, [x9], #16 + fmla v29.4s, v4.4s, v21.4s + ld1 {v7.4s}, [x19], #16 + ld1 {v24.4s}, [x9], #16 + fmla v29.4s, v5.4s, v22.4s + ld1 {v16.4s}, [x20], #16 + ld1 {v25.4s}, [x9], #16 + fmla v29.4s, v6.4s, v23.4s + fmla v29.4s, v7.4s, v24.4s + fmla v29.4s, v16.4s, v25.4s + + cbnz x8, LeftRelu6 + cbnz x7, LeftRelu + b LeftWrite + LeftRelu6: + fmin v29.4s, v29.4s, v31.4s + LeftRelu: + fmax v29.4s, v29.4s, v30.4s + LeftWrite: + cmp x11, #4 + bne Write3 + st1 {v29.4s}, [x0], #16 + b NextPixel + Write3: + sxtw x11, w11 + tbnz w11, #1, Write2 + tbnz w11, #0, Write1 + Write2: + st1 {v29.2s}, [x0], #8 + ext v29.16b, v29.16b, v29.16b, #8 + tbz w11, #0, NextPixel + Write1: + str s29, [x0], #4 + + NextPixel: + add x1, x1, x6 + sub x5, x5, #1 + cmp x5, #0 + bgt LoopPixel +End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 +ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S new file mode 100644 index 0000000000000000000000000000000000000000..87f48ac3871814f42eb9f805f614eecd1dc507fe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S @@ -0,0 +1,304 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
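+ */
+
+/*
+ * 5x5 variant of the indirect fp32 depthwise kernel: `input` holds 25 row
+ * pointers per output pixel, so padding is typically resolved when the
+ * indirection buffer is built rather than inside the kernel. Roughly, per
+ * output pixel (illustrative C only, taps[k] being the 25 gathered pointers):
+ *
+ *   for (int c = 0; c < channels; c += 4)
+ *     out[c..c+3] = act(bias[c..c+3] + sum_k taps[k][c..c+3] * w[k][c..c+3]);
+ *
+ * Channels run four at a time with a 1/2/3-channel tail (Write1/2/3).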
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, +// size_t input_stride, size_t relu, size_t relu6) +// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6 + +asm_function ConvDwFp32Indirect5x5 + sub sp, sp, #176 + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + stp x23, x24, [sp, #96] + stp x25, x26, [sp, #112] + stp x27, x28, [sp, #128] + stp x29, x30, [sp, #144] + ldrb w8, [sp, #176] + stp x2, x3, [sp] + stp x4, x6, [sp, #16] + stp x7, x8, [sp, #32] + stp x0, x1, [sp, #160] + + movi v31.4s, #6 + scvtf v31.4s, v31.4s + dup v30.4s, wzr + + mov x3, x5 + cmp x3, #0 + beq End + + LoopPixel: + ldp x5, x4, [sp] // weight, bias + ld1 {v29.4s}, [x4], #16 + ldr x2, [sp, #16] // channel + + ldp x6, x7, [x1] + ldp x8, x9, [x1, #16] + ldp x10, x11, [x1, #32] + ldp x12, x13, [x1, #48] + ldp x14, x15, [x1, #64] + ldp x16, x17, [x1, #80] + ldp x0, x19, [x1, #96] + ldp x20, x21, [x1, #112] + ldp x22, x23, [x1, #128] + ldp x24, x25, [x1, #144] + ldp x26, x27, [x1, #160] + ldp x28, x29, [x1, #176] + ldr x30, [x1, #192] + + ld1 {v0.4s}, [x6], #16 + ld1 {v1.4s}, [x7], #16 + ld1 {v2.4s}, [x8], #16 + ld1 {v3.4s}, [x9], #16 + ld1 {v4.4s}, [x10], #16 + + ld1 {v18.4s}, [x5], #16 + ld1 {v19.4s}, [x5], #16 + ld1 {v20.4s}, [x5], #16 + ld1 {v21.4s}, [x5], #16 + ld1 {v22.4s}, [x5], #16 + stp x5, x4, [sp, #48] + + cmp x2, #4 + ble LeftLoop + LoopC4: + ldr x5, [sp, #48] + // column 0 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x11], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x12], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x13], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x14], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x15], #16 + ld1 {v27.4s}, [x5], #16 + // column 1 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x16], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x17], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x0], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x19], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x20], #16 + ld1 {v22.4s}, [x5], #16 + // column 2 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x21], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x22], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x23], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x24], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x25], #16 + ld1 {v27.4s}, [x5], #16 + // column 3 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x26], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x27], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x28], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x29], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x30], #16 + ld1 {v22.4s}, [x5], #16 + // column 4 + fmla v29.4s, v0.4s, v18.4s + fmla v29.4s, v1.4s, v19.4s + ld1 {v0.4s}, [x6], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v1.4s}, [x7], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v2.4s}, [x8], #16 + ld1 
{v20.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v3.4s}, [x9], #16 + ld1 {v21.4s}, [x5], #16 + ld1 {v4.4s}, [x10], #16 + ld1 {v22.4s}, [x5], #16 + str x5, [sp, #48] + + ldp x4, x5, [sp, #32] + cbnz x5, RELU6 + cbnz x4, RELU + b WRITE + RELU6: + fmin v29.4s, v29.4s, v31.4s + RELU: + fmax v29.4s, v29.4s, v30.4s + WRITE: + ldr x4, [sp, #160] + st1 {v29.4s}, [x4], #16 + str x4, [sp, #160] + + ldr x4, [sp, #56] + ld1 {v29.4s}, [x4], #16 + str x4, [sp, #56] + sub x2, x2, #4 + cmp x2, #4 + bgt LoopC4 + + LeftLoop: + // column 0 + ldr x5, [sp, #48] + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x11], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x12], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x13], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x14], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x15], #16 + ld1 {v27.4s}, [x5], #16 + // column 1 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x16], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x17], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x0], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x19], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x20], #16 + ld1 {v22.4s}, [x5], #16 + // column 2 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x21], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x22], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x23], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x24], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x25], #16 + ld1 {v27.4s}, [x5], #16 + // column 3 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x26], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x27], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x28], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x29], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x30], #16 + ld1 {v22.4s}, [x5], #16 + // column 4 + fmla v29.4s, v0.4s, v18.4s + fmla v29.4s, v1.4s, v19.4s + fmla v29.4s, v2.4s, v20.4s + fmla v29.4s, v3.4s, v21.4s + fmla v29.4s, v4.4s, v22.4s + + ldp x4, x5, [sp, #32] + cbnz x5, LeftRelu6 + cbnz x4, LeftRelu + b LeftWrite + LeftRelu6: + fmin v29.4s, v29.4s, v31.4s + LeftRelu: + fmax v29.4s, v29.4s, v30.4s + LeftWrite: + cmp x2, #4 + bne Write3 + ldr x4, [sp, #160] + st1 {v29.4s}, [x4], #16 + str x4, [sp, #160] + b NextPixel + Write3: + sxtw x2, w2 + tbnz w2, #1, Write2 + tbnz w2, #0, Write1 + Write2: + ldr x4, [sp, #160] + st1 {v29.2s}, [x4], #8 + str x4, [sp, #160] + ext v29.16b, v29.16b, v29.16b, #8 + tbz w2, #0, NextPixel + Write1: + ldr x4, [sp, #160] + str s29, [x4], #4 + str x4, [sp, #160] + + NextPixel: + ldr x2, [sp, #24] + add x1, x1, x2 + sub x3, x3, #1 + cmp x3, #0 + bgt LoopPixel +End: + ldp x19, x20, [sp, #64] + ldp x21, x22, [sp, #80] + ldp x23, x24, [sp, #96] + ldp x25, x26, [sp, #112] + ldp x27, x28, [sp, #128] + ldp x29, x30, [sp, #144] + add sp, sp, #176 +ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S new file mode 100644 index 0000000000000000000000000000000000000000..59923da401fbe748f2717c5eb6460244600600cb --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Row(float* output_ptr, const float* input_ptr,const float* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels, +// x4: input_channel, x5: input_step +// +asm_function ConvDwFp32Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +ble End + +mov x9, x0 +mov x12, #4 +mul x5, x5, x12 + +LoopOutPixel: +mov x6, x1 +mov x7, x2 +mov x8, x4 + + LoopDepth16In: + cmp x8, #16 + blt L4 + sub x8, x8, #16 + + ld1 {v0.4s, v1.4s}, [x6], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + cmp x8, #16 + blt LoopDepth16Out + LoopDepth16: + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v3.4s + + st1 {v16.4s, v17.4s}, [x9], #32 + + ld1 {v4.4s, v5.4s}, [x6], #32 + ld1 {v6.4s, v7.4s}, [x7], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + + fmla v18.4s, v4.4s, v6.4s + fmla v19.4s, v5.4s, v7.4s + + st1 {v18.4s, v19.4s}, [x9], #32 + + ld1 {v0.4s, v1.4s}, [x6], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + sub x8, x8, #16 + cmp x8, #16 + bge LoopDepth16 + + LoopDepth16Out: + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v3.4s + st1 {v16.4s, v17.4s}, [x9], #32 + + ld1 {v4.4s, v5.4s}, [x6], #32 + ld1 {v6.4s, v7.4s}, [x7], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + + fmla v18.4s, v4.4s, v6.4s + fmla v19.4s, v5.4s, v7.4s + + st1 {v18.4s, v19.4s}, [x9], #32 + + L4: + cmp x8, #4 + blt L0 + + LoopDepth4: + ld1 {v0.4s}, [x6], #16 + ld1 {v2.4s}, [x7], #16 + ld1 {v16.4s}, [x0], #16 + fmla v16.4s, v0.4s, v2.4s + st1 {v16.4s}, [x9], #16 + sub x8, x8, #4 + cmp x8, #4 + bge LoopDepth4 + + L0: + cmp x8, #0 + beq Loop16LineEnd + + LoopDepth0: + ldr s0, [x6], #4 + ldr s1, [x7], #4 + ldr s2, [x0], #4 + fmul s0, s0, s1 + fadd s2, s2, s0 + str s2, [x9], #4 + subs x8, x8, #1 + bne LoopDepth0 + + Loop16LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..2648795e8b92faccada30b211a82b2b4882ee872 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S @@ -0,0 +1,294 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, +// size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, +// int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, +// int32_t *acc_min, int32_t *acc_max) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: in_zp, #56: out_zp, #64: out_multiplier, #72:left_shift, #80: right_shift, #88: acc_min, #96: acc_max +asm_function ConvDwInt8Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + + ldr x14, [sp, #240] // input_zp + ld1 {v19.8b}, [x14], #8 + + ldr x15, [sp, #248] // output_zp + ld1 {v20.4s}, [x15], #16 + ld1 {v21.4s}, [x15], #16 + + ldr x16, [sp, #256] // out_multiplier + ld1 {v22.4s}, [x16], #16 + ld1 {v23.4s}, [x16], #16 + + ldr x17, [sp, #264] // left_shift + ld1 {v24.4s}, [x17], #16 + ld1 {v25.4s}, [x17], #16 + + ldr x25, [sp, #272] // right shift + ld1 {v26.4s}, [x25], #16 + ld1 {v27.4s}, [x25], #16 + + ldr x19, [sp, #280] // acc_min + ld1 {v28.4s}, [x19], #16 + ld1 {v29.4s}, [x19], #16 + + ldr x20, [sp, #288] // acc_max + ld1 {v30.4s}, [x20], #16 + ld1 {v31.4s}, [x20], #16 + + ld1 {v17.4s}, [x3], #16 + ld1 {v18.4s}, [x3], #16 + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + + LoopW4: + mov x19, #4 + mul x19, x19, x11 + mov x25, #4 + mul x25, x25, x9 + + mov x16, x23 + mov x17, x2 + mov x20, x6 + + mov v0.16b, v17.16b + mov v1.16b, v18.16b + mov v2.16b, v17.16b + mov v3.16b, v18.16b + mov v4.16b, v17.16b + mov v5.16b, v18.16b + mov v6.16b, v17.16b + mov v7.16b, v18.16b + LoopKh4: + mov x25, x7 + mov x21, x16 + LoopKw4: + mov x22, x21 + ld1 {v16.8h}, [x17], #16 + + ld1 {v15.8b}, [x22], x11 + ssubl v14.8h, v15.8b, v19.8b + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + + ld1 {v13.8b}, [x22], x11 + ssubl v12.8h, v13.8b, v19.8b + smlal v2.4s, v12.4h, v16.4h + smlal2 v3.4s, v12.8h, v16.8h + + ld1 {v11.8b}, [x22], x11 + ssubl v10.8h, v11.8b, v19.8b + smlal v4.4s, v10.4h, v16.4h + smlal2 v5.4s, v10.8h, v16.8h + + ld1 {v9.8b}, [x22], x11 + ssubl v8.8h, v9.8b, v19.8b + smlal v6.4s, 
v8.4h, v16.4h + smlal2 v7.4s, v8.8h, v16.8h + + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw4 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh4 + + sqshl v0.4s, v0.4s, v24.4s + sqshl v1.4s, v1.4s, v25.4s + sqshl v2.4s, v2.4s, v24.4s + sqshl v3.4s, v3.4s, v25.4s + sqshl v4.4s, v4.4s, v24.4s + sqshl v5.4s, v5.4s, v25.4s + sqshl v6.4s, v6.4s, v24.4s + sqshl v7.4s, v7.4s, v25.4s + + sqrdmulh v0.4s, v0.4s, v22.4s + sqrdmulh v1.4s, v1.4s, v23.4s + sqrdmulh v2.4s, v2.4s, v22.4s + sqrdmulh v3.4s, v3.4s, v23.4s + sqrdmulh v4.4s, v4.4s, v22.4s + sqrdmulh v5.4s, v5.4s, v23.4s + sqrdmulh v6.4s, v6.4s, v22.4s + sqrdmulh v7.4s, v7.4s, v23.4s + + sqrshl v0.4s, v0.4s, v26.4s + sqrshl v1.4s, v1.4s, v27.4s + sqrshl v2.4s, v2.4s, v26.4s + sqrshl v3.4s, v3.4s, v27.4s + sqrshl v4.4s, v4.4s, v26.4s + sqrshl v5.4s, v5.4s, v27.4s + sqrshl v6.4s, v6.4s, v26.4s + sqrshl v7.4s, v7.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + add v1.4s, v1.4s, v21.4s + add v2.4s, v2.4s, v20.4s + add v3.4s, v3.4s, v21.4s + add v4.4s, v4.4s, v20.4s + add v5.4s, v5.4s, v21.4s + add v6.4s, v6.4s, v20.4s + add v7.4s, v7.4s, v21.4s + smax v0.4s, v0.4s, v28.4s + smax v1.4s, v1.4s, v29.4s + smax v2.4s, v2.4s, v28.4s + smax v3.4s, v3.4s, v29.4s + smax v4.4s, v4.4s, v28.4s + smax v5.4s, v5.4s, v29.4s + smax v6.4s, v6.4s, v28.4s + smax v7.4s, v7.4s, v29.4s + smin v0.4s, v0.4s, v30.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v30.4s + smin v3.4s, v3.4s, v31.4s + smin v4.4s, v4.4s, v30.4s + smin v5.4s, v5.4s, v31.4s + smin v6.4s, v6.4s, v30.4s + smin v7.4s, v7.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + sqxtn v4.4h, v4.4s + sqxtn v5.4h, v5.4s + sqxtn v6.4h, v6.4s + sqxtn v7.4h, v7.4s + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + sqxtn v4.8b, v4.8h + sqxtn v5.8b, v5.8h + sqxtn v6.8b, v6.8h + sqxtn v7.8b, v7.8h + + mov x16, x3 + add x17, x16, x9 + add x25, x17, x9 + add x21, x25, x9 + + st1 {v0.s}[0], [x16], #4 + st1 {v1.s}[0], [x16], #4 + st1 {v2.s}[0], [x17], #4 + st1 {v3.s}[0], [x17], #4 + st1 {v4.s}[0], [x25], #4 + st1 {v5.s}[0], [x25], #4 + st1 {v6.s}[0], [x21], #4 + st1 {v7.s}[0], [x21], #4 + + add x3, x3, x25 + add x23, x23, x19 + sub x24, x24, #4 + cmp x24, #0 + ble LoopWEnd + cmp x24, #4 + bge LoopW4 + + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v17.16b + mov v1.16b, v18.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v15.8b}, [x22], x13 + ssubl v14.8h, v15.8b, v19.8b + ld1 {v16.8h}, [x17], #16 + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + + sqshl v0.4s, v0.4s, v24.4s + sqrdmulh v0.4s, v0.4s, v22.4s + sqshl v1.4s, v1.4s, v25.4s + sqrdmulh v1.4s, v1.4s, v23.4s + + sqrshl v0.4s, v0.4s, v26.4s + sqrshl v1.4s, v1.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + smax v0.4s, v0.4s, v28.4s + smin v0.4s, v0.4s, v30.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + add v1.4s, v1.4s, v21.4s + smax v1.4s, v1.4s, v29.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v1.4h, v1.4s + sqxtn v1.8b, v1.8h + + mov x17, x3 + st1 {v0.s}[0], [x17], #4 + st1 {v1.s}[0], [x17], #4 + add x3, x3, x9 + + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret 
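The inner LoopKh/LoopKw body of ConvDwInt8Center above widens each int8 input with ssubl (subtracting the input zero point held in v19) and multiply-accumulates against int16 weights via smlal/smlal2 into int32 accumulators. A scalar C sketch of that core for one output pixel, assuming weights are packed as 8 int16 values per kernel tap and strides are in elements; the names are illustrative, not from the patch:

#include <stdint.h>
#include <stddef.h>

/* One output pixel, 8 channels: acc[c] += (src - in_zp) * weight, walking
 * kernel_h x kernel_w taps exactly as LoopKh/LoopKw do. */
static void conv_dw_int8_pixel(int32_t acc[8], const int8_t *src,
                               const int16_t *weight, size_t kernel_h,
                               size_t kernel_w, size_t in_kh_step,
                               size_t in_kw_step, int8_t in_zp) {
  for (size_t kh = 0; kh < kernel_h; ++kh) {
    for (size_t kw = 0; kw < kernel_w; ++kw) {
      const int8_t *in = src + kh * in_kh_step + kw * in_kw_step;
      const int16_t *w = weight + (kh * kernel_w + kw) * 8;
      for (int c = 0; c < 8; ++c) {
        acc[c] += (int16_t)(in[c] - in_zp) * w[c];   /* ssubl + smlal/smlal2 */
      }
    }
  }
}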
+#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S new file mode 100644 index 0000000000000000000000000000000000000000..8d678174ec5a3bdcf15e0c7bf58942f49fefa128 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S @@ -0,0 +1,191 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, +// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +asm_function ConvDwInt8PostAlign4 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + dup v26.4s, w5 + dup v27.4s, w4 + dup v28.4s, w6 + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + cmp x2, #16 + blt LoopDepth8 + + LoopDepth16: + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + ld1 {v2.4s}, [x1], #16 + ld1 {v3.4s}, [x1], #16 + + cbz w5, RightShiftDepth16 + sqshl v0.4s, v0.4s, v26.4s + sqshl v1.4s, v1.4s, v26.4s + sqshl v2.4s, v2.4s, v26.4s + sqshl v3.4s, v3.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + sqrdmulh v2.4s, v2.4s, v27.4s + sqrdmulh v3.4s, v3.4s, v27.4s + b AddZpDepth16 + + RightShiftDepth16: + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + sqrdmulh v2.4s, v2.4s, v27.4s + sqrdmulh v3.4s, v3.4s, v27.4s + + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + and v5.16b, v1.16b, v28.16b + sshr v5.4s, v5.4s, #31 + sqadd v1.4s, v1.4s, v5.4s + srshl v1.4s, v1.4s, v28.4s + and v6.16b, v2.16b, v28.16b + sshr v6.4s, v6.4s, #31 + sqadd v2.4s, v2.4s, v6.4s + srshl v2.4s, v2.4s, v28.4s + and v7.16b, v3.16b, v28.16b + sshr v7.4s, v7.4s, #31 + sqadd v3.4s, v3.4s, v7.4s + srshl v3.4s, v3.4s, v28.4s + + AddZpDepth16: + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + add v2.4s, v2.4s, v29.4s + add v3.4s, v3.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smax v2.4s, v2.4s, v30.4s + smax v3.4s, v3.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v31.4s + smin v3.4s, v3.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + 
st1 {v2.s}[0], [x0], #4 + st1 {v3.s}[0], [x0], #4 + + sub x2, x2, #16 + cmp x2, #16 + bge LoopDepth16 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + cbz w5, RightShiftDepth8 + sqshl v0.4s, v0.4s, v26.4s + sqshl v1.4s, v1.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + b AddZpDepth8 + + RightShiftDepth8: + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + and v5.16b, v1.16b, v28.16b + sshr v5.4s, v5.4s, #31 + sqadd v1.4s, v1.4s, v5.4s + srshl v1.4s, v1.4s, v28.4s + + AddZpDepth8: + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + + sqshl v0.4s, v0.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 0000000000000000000000000000000000000000..7f14f7c2ddc4eb1502563f4183b29b81c156b886 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
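The and/sshr/sqadd/srshl group in ConvDwInt8PostAlign4 above is the usual rounding-divide-by-power-of-two fixup: srshl with a negative shift rounds ties toward +infinity, so negative accumulators are first nudged down by one, making ties round away from zero. A scalar sketch, assuming an arithmetic right shift on signed int and ignoring the saturation of sqadd:

#include <stdint.h>

/* shift is the non-positive value loaded into v28; -shift is the shift width.
 * (x & shift) keeps the sign bit of x whenever shift is negative. */
static int32_t rounding_shift_right(int32_t x, int32_t shift) {
  int32_t fixup = (x & shift) >> 31;     /* and + sshr #31: -1 if x < 0 */
  int32_t fixed = x + fixup;             /* sqadd (saturation omitted) */
  int n = -shift;
  return n > 0 ? (fixed + (1 << (n - 1))) >> n : fixed;   /* srshl */
}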
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: channel4, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +asm_function ConvDwInt8PostAlign4PerChannel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such an amount of parameters + ldr x8, [sp] + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + ld1 {v2.4s}, [x5], #16 + ld1 {v3.4s}, [x5], #16 + + ld1 {v4.4s}, [x4], #16 + ld1 {v5.4s}, [x4], #16 + + sqshl v0.4s, v0.4s, v2.4s + sqshl v1.4s, v1.4s, v3.4s + + ld1 {v6.4s}, [x6], #16 + ld1 {v7.4s}, [x6], #16 + + sqrdmulh v0.4s, v0.4s, v4.4s + sqrdmulh v1.4s, v1.4s, v5.4s + and v16.16b, v0.16b, v6.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + and v17.16b, v1.16b, v7.16b + sshr v17.4s, v17.4s, #31 + sqadd v1.4s, v1.4s, v17.4s + srshl v1.4s, v1.4s, v7.4s + + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + ld1 {v2.4s}, [x5], #16 + + sqshl v0.4s, v0.4s, v2.4s + + ld1 {v4.4s}, [x4], #16 + sqrdmulh v0.4s, v0.4s, v4.4s + + ld1 {v6.4s}, [x6], #16 + and v16.16b, v0.16b, v6.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S new file mode 100644 index 0000000000000000000000000000000000000000..5828d57525845fd11ef9c4f818cdd395c11ab4d3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S @@ -0,0 +1,134 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
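ConvDwInt8PostAlign4PerChannel above differs from the per-layer variant only in that each group of four channels loads its own multiplier and shift vectors from out_multiplier, left_shift, and right_shift. A scalar reference of the whole per-channel post-processing, combining the requantization steps sketched earlier; saturating behavior is modeled with plain C arithmetic, so this is a sketch rather than a bit-exact implementation:

#include <stdint.h>

static int8_t requant_one_ch(int32_t acc, int32_t left_shift, int32_t mult,
                             int32_t right_shift, int32_t out_zp,
                             int32_t acc_min, int32_t acc_max) {
  int64_t prod = (int64_t)(acc << left_shift) * mult;      /* sqshl + sqrdmulh */
  int32_t high = (int32_t)((prod + (1LL << 30)) >> 31);    /* rounding high half */
  int32_t fixup = (high & right_shift) >> 31;              /* and + sshr #31 */
  int n = -right_shift;
  int32_t r = n > 0 ? ((high + fixup + (1 << (n - 1))) >> n) : high;  /* srshl */
  r += out_zp;
  if (r < acc_min) r = acc_min;
  if (r > acc_max) r = acc_max;
  return (int8_t)r;
}

/* Reference walk over channel4 outputs: one multiplier/shift per channel. */
void ConvDwInt8PostAlign4PerChannelRef(int8_t *dst, const int32_t *buffer,
                                       int channel4, int32_t output_zp,
                                       const int32_t *out_multiplier,
                                       const int32_t *left_shift,
                                       const int32_t *right_shift,
                                       int32_t acc_min, int32_t acc_max) {
  for (int c = 0; c < channel4; ++c) {
    dst[c] = requant_one_ch(buffer[c], left_shift[c], out_multiplier[c],
                            right_shift[c], output_zp, acc_min, acc_max);
  }
}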
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, +// int output_channel, int input_step, int8_t input_zp) +// x0: output_ptr, x1: input_ptr, x2: weight_ptr, x3: num_pixels, +// x4: output_channel, x5: input_step, x6: input_zp +// +asm_function ConvDwInt8Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +beq End + +mov x10, x0 + +dup v31.8b, w6 + +LoopOutPixel: +mov x7, x1 +mov x8, x2 +mov x9, x4 + + LoopDepth16In: + cmp x9, #16 + blt L8 + sub x9, x9, #16 + + ld1 {v0.8b, v1.8b}, [x7], #16 + ld1 {v2.8h, v3.8h}, [x8], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + + + cmp x9, #16 + blt LoopDepth16Out + LoopDepth16: + + st1 {v16.4s, v17.4s}, [x10], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + ssubl v21.8h, v1.8b, v31.8b + smlal v18.4s, v21.4h, v3.4h + smlal2 v19.4s, v21.8h, v3.8h + st1 {v18.4s, v19.4s}, [x10], #32 + + ld1 {v0.8b, v1.8b}, [x7], #16 + ld1 {v2.8h, v3.8h}, [x8], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + + sub x9, x9, #16 + cmp x9, #16 + bge LoopDepth16 + + LoopDepth16Out: + + st1 {v16.4s, v17.4s}, [x10], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + ssubl v21.8h, v1.8b, v31.8b + smlal v18.4s, v21.4h, v3.4h + smlal2 v19.4s, v21.8h, v3.8h + st1 {v18.4s, v19.4s}, [x10], #32 + + L8: + cmp x9, #8 + blt L0 + + LoopDepth8: + ld1 {v0.8b}, [x7], #8 + ld1 {v2.8h}, [x8], #16 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + st1 {v16.4s, v17.4s}, [x10], #32 + + sub x9, x9, #8 + cmp x9, #8 + bge LoopDepth8 + + L0: + cmp x9, #0 + beq Loop16LineEnd + + LoopDepth0: + ldrsb w14, [x7], #1 + ldrsh w15, [x8], #2 + ldr w16, [x0], #4 + sub w14, w14, w6 + + sxth w14, w14 + madd w14, w14, w15, w16 + str w14, [x10], #4 + + subs x9, x9, #1 + bne LoopDepth0 + + Loop16LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..9cead57ce4620cc291c606f6ed0bcb5ccc419f2c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S @@ -0,0 +1,458 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
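The scalar LoopDepth0 tail of ConvDwInt8Row above spells out what the SIMD loops do sixteen or eight channels at a time: acc += (input - input_zp) * weight, with the weight pointer rewinding at every pixel and the input advancing by input_step. A C reference of the whole routine, assuming strides are given in elements:

#include <stdint.h>

/* Scalar model of ConvDwInt8Row: outputs are laid out [pixel][channel],
 * the weight row is reused for every pixel. */
void ConvDwInt8RowRef(int32_t *output, const int8_t *input,
                      const int16_t *weight, int num_pixels,
                      int output_channel, int input_step, int8_t input_zp) {
  for (int p = 0; p < num_pixels; ++p) {
    for (int c = 0; c < output_channel; ++c) {
      output[c] += (input[c] - input_zp) * weight[c];   /* ssubl + smlal */
    }
    output += output_channel;
    input += input_step;
  }
}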
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: ic4, x11: in_sh_step, x12: in_sw_step, x13: in_kh_step, x14: in_kw_step +// x26: relu, x16: relu6 +asm_function ConvSwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x8, [sp, #208] + ldr x9, [sp, #216] + ldr x10, [sp, #224] + ldr x11, [sp, #232] + ldr x12, [sp, #240] + ldr x13, [sp, #248] + ldr x14, [sp, #256] + mul x15, x6, x7 + mul x15, x10, x15 + mov x16, #16 + mul x15, x15, x16 + + ld1 {v25.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x17, x1 + mov x28, x5 + mov x3, x0 + cmp x28, #8 + blt LoopW + cmp x28, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x12 + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + mov v1.16b, v25.16b + mov v2.16b, v25.16b + mov v3.16b, v25.16b + mov v4.16b, v25.16b + mov v5.16b, v25.16b + mov v6.16b, v25.16b + mov v7.16b, v25.16b + mov v8.16b, v25.16b + mov v9.16b, v25.16b + mov v10.16b, v25.16b + mov v11.16b, v25.16b + mov v12.16b, v25.16b + mov v13.16b, v25.16b + mov v14.16b, v25.16b + mov v15.16b, v25.16b + LoopKh16: + mov x23, x7 + mov x24, x20 + LoopKw16: + mov x25, x24 + mov x27, x10 + LoopIc16: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v28.4s, v17.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v1.4s, v30.4s, v17.s[2] + fmla v0.4s, v31.4s, v16.s[3] + fmla v1.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v2.4s, v28.4s, v18.s[0] + fmla v3.4s, v28.4s, v19.s[0] + fmla v2.4s, v29.4s, v18.s[1] + fmla v3.4s, v29.4s, v19.s[1] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v30.4s, v19.s[2] + fmla v2.4s, v31.4s, v18.s[3] + fmla v3.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v28.4s, v21.s[0] + fmla v4.4s, v29.4s, v20.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v4.4s, v30.4s, v20.s[2] + fmla v5.4s, v30.4s, v21.s[2] + fmla v4.4s, v31.4s, v20.s[3] + fmla v5.4s, v31.4s, v21.s[3] + ld1 
{v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + fmla v6.4s, v28.4s, v22.s[0] + fmla v7.4s, v28.4s, v23.s[0] + fmla v6.4s, v29.4s, v22.s[1] + fmla v7.4s, v29.4s, v23.s[1] + fmla v6.4s, v30.4s, v22.s[2] + fmla v7.4s, v30.4s, v23.s[2] + fmla v6.4s, v31.4s, v22.s[3] + fmla v7.4s, v31.4s, v23.s[3] + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v8.4s, v28.4s, v16.s[0] + fmla v9.4s, v28.4s, v17.s[0] + fmla v8.4s, v29.4s, v16.s[1] + fmla v9.4s, v29.4s, v17.s[1] + fmla v8.4s, v30.4s, v16.s[2] + fmla v9.4s, v30.4s, v17.s[2] + fmla v8.4s, v31.4s, v16.s[3] + fmla v9.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v10.4s, v28.4s, v18.s[0] + fmla v11.4s, v28.4s, v19.s[0] + fmla v10.4s, v29.4s, v18.s[1] + fmla v11.4s, v29.4s, v19.s[1] + fmla v10.4s, v30.4s, v18.s[2] + fmla v11.4s, v30.4s, v19.s[2] + fmla v10.4s, v31.4s, v18.s[3] + fmla v11.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v12.4s, v28.4s, v20.s[0] + fmla v13.4s, v28.4s, v21.s[0] + fmla v12.4s, v29.4s, v20.s[1] + fmla v13.4s, v29.4s, v21.s[1] + fmla v12.4s, v30.4s, v20.s[2] + fmla v13.4s, v30.4s, v21.s[2] + fmla v12.4s, v31.4s, v20.s[3] + fmla v13.4s, v31.4s, v21.s[3] + fmla v14.4s, v28.4s, v22.s[0] + fmla v15.4s, v28.4s, v23.s[0] + fmla v14.4s, v29.4s, v22.s[1] + fmla v15.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v22.s[2] + fmla v15.4s, v30.4s, v23.s[2] + fmla v14.4s, v31.4s, v22.s[3] + fmla v15.4s, v31.4s, v23.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc16 + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw16 + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh16 + ldr x16, [sp, #272] + cbnz x16, Relu616 + ldr x26, [sp, #264] + cbnz x26, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x17, x17, x19 + sub x28, x28, #16 + cmp x28, #0 + ble LoopWEnd + cmp x28, #8 + blt LoopW + cmp x28, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x12 + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + mov v1.16b, v25.16b + mov v2.16b, v25.16b + mov v3.16b, v25.16b + mov v4.16b, v25.16b + mov v5.16b, v25.16b + mov v6.16b, v25.16b + mov v7.16b, v25.16b + LoopKh8: + mov x23, x7 + mov x24, x20 + LoopKw8: + mov x25, x24 + mov 
x27, x10 + LoopIc8: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v28.4s, v17.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v1.4s, v30.4s, v17.s[2] + fmla v0.4s, v31.4s, v16.s[3] + fmla v1.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v2.4s, v28.4s, v18.s[0] + fmla v3.4s, v28.4s, v19.s[0] + fmla v2.4s, v29.4s, v18.s[1] + fmla v3.4s, v29.4s, v19.s[1] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v30.4s, v19.s[2] + fmla v2.4s, v31.4s, v18.s[3] + fmla v3.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v28.4s, v21.s[0] + fmla v4.4s, v29.4s, v20.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v4.4s, v30.4s, v20.s[2] + fmla v5.4s, v30.4s, v21.s[2] + fmla v4.4s, v31.4s, v20.s[3] + fmla v5.4s, v31.4s, v21.s[3] + fmla v6.4s, v28.4s, v22.s[0] + fmla v7.4s, v28.4s, v23.s[0] + fmla v6.4s, v29.4s, v22.s[1] + fmla v7.4s, v29.4s, v23.s[1] + fmla v6.4s, v30.4s, v22.s[2] + fmla v7.4s, v30.4s, v23.s[2] + fmla v6.4s, v31.4s, v22.s[3] + fmla v7.4s, v31.4s, v23.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc8 + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw8 + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh8 + ldr x16, [sp, #272] + cbnz x16, Relu68 + ldr x26, [sp, #264] + cbnz x26, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x17, x17, x19 + sub x28, x28, #8 + cmp x28, #0 + ble LoopWEnd + cmp x28, #8 + bge LoopW8 + LoopW: + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + LoopKh: + mov x23, x7 + mov x24, x20 + LoopKw: + mov x25, x24 + mov x27, x10 + LoopIc: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + fmla v0.4s, v28.4s, v16.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v0.4s, v31.4s, v16.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh + ldr x16, [sp, #272] + cbnz x16, Relu6 
+ ldr x26, [sp, #264] + cbnz x26, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x17, x17, x12 + subs x28, x28, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x11 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..7820c52148a37c762578dbdb7c26ed2de20a9ec6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S @@ -0,0 +1,421 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv1x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + prfm pldl1keep, [x23] + mov x24, x23 + mov x25, x10 + subs x25, x25, #16 + blt LoopC12 + LoopC16: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x24], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, 
v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v2.4s, v18.4s, v6.s[0] + fmla v3.4s, v19.4s, v6.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[1] + fmla v1.4s, v17.4s, v6.s[1] + fmla v2.4s, v18.4s, v6.s[1] + fmla v3.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v2.4s, v18.4s, v6.s[2] + fmla v3.4s, v19.4s, v6.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[3] + fmla v1.4s, v17.4s, v6.s[3] + fmla v2.4s, v18.4s, v6.s[3] + fmla v3.4s, v19.4s, v6.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[0] + fmla v1.4s, v17.4s, v7.s[0] + fmla v2.4s, v18.4s, v7.s[0] + fmla v3.4s, v19.4s, v7.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[1] + fmla v1.4s, v17.4s, v7.s[1] + fmla v2.4s, v18.4s, v7.s[1] + fmla v3.4s, v19.4s, v7.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[2] + fmla v1.4s, v17.4s, v7.s[2] + fmla v2.4s, v18.4s, v7.s[2] + fmla v3.4s, v19.4s, v7.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[3] + fmla v1.4s, v17.4s, v7.s[3] + fmla v2.4s, v18.4s, v7.s[3] + fmla v3.4s, v19.4s, v7.s[3] + subs x25, x25, #16 + bge LoopC16 + LoopC12: + adds x25, x25, #16 + cbz x25, LoopCEnd + cmp x25, #12 + blt LoopC8 + ld1 {v4.4s, v5.4s, v6.4s}, [x24], #48 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla 
v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v2.4s, v18.4s, v6.s[0] + fmla v3.4s, v19.4s, v6.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[1] + fmla v1.4s, v17.4s, v6.s[1] + fmla v2.4s, v18.4s, v6.s[1] + fmla v3.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v2.4s, v18.4s, v6.s[2] + fmla v3.4s, v19.4s, v6.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[3] + fmla v1.4s, v17.4s, v6.s[3] + fmla v2.4s, v18.4s, v6.s[3] + fmla v3.4s, v19.4s, v6.s[3] + sub x25, x25, #12 + b LoopCTail + LoopC8: + cmp x25, #8 + blt LoopC4 + ld1 {v4.4s, v5.4s}, [x24], #32 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v4.4s}, [x24], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + cmp x25, #2 + beq LoopC2 + cmp x25, #1 + beq LoopC1 + // LoopC3 + ld3r {v4.4s, v5.4s, v6.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.4s + 
fmla v1.4s, v17.4s, v4.4s + fmla v2.4s, v18.4s, v4.4s + fmla v3.4s, v19.4s, v4.4s + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.4s + fmla v1.4s, v17.4s, v5.4s + fmla v2.4s, v18.4s, v5.4s + fmla v3.4s, v19.4s, v5.4s + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.4s + fmla v1.4s, v17.4s, v6.4s + fmla v2.4s, v18.4s, v6.4s + fmla v3.4s, v19.4s, v6.4s + b LoopCEnd + LoopC2: + ld1 {v4.d}[0], [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + b LoopCEnd + LoopC1: + ld1r {v4.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v2.4s, v18.4s, v4.4s + fmla v3.4s, v19.4s, v4.4s + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v4.2d, xzr // relu + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v4.4s + fmax v2.4s, v2.4s, v4.4s + fmax v3.4s, v3.4s, v4.4s + + ands x6, x6, #1 + beq WriteBack + movi v4.4s, #6 // relu6 + scvtf v4.4s, v4.4s + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v4.4s + fmin v2.4s, v2.4s, v4.4s + fmin v3.4s, v3.4s, v4.4s + fmin v4.4s, v4.4s, v4.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + b End + NC4HW4: + add x21, x0, x7, LSL #1 + add x22, x20, x7, LSL #1 + st1 {v0.4s}, [x0] + st1 {v1.4s}, [x20] + st1 {v2.4s}, [x21] + st1 {v3.4s}, [x22] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..7ed045f20520444260e5edf42ca553e30cfae0ad --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S @@ -0,0 +1,278 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
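The epilogue of SWConv1x16Kernel above applies the activation flags and then either stores the 16 output channels contiguously or, when write_mode is 13, scatters the four 4-channel blocks with an out_step stride (NC4HW4). A sketch of that logic; the flag layout (any set low bit clamps at zero, bit 0 additionally clamps at six) is inferred from the ands/beq sequence, and out_step is taken in float elements here rather than the byte stride the assembly uses:

#include <stddef.h>

/* out[i][j]: accumulator i (v0..v3), lane j; 16 channels for one pixel. */
static void sw_act_and_write(float *dst, float out[4][4], size_t act_flag,
                             size_t out_step, size_t write_mode) {
  if (act_flag & 3) {
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        if (out[i][j] < 0.0f) out[i][j] = 0.0f;                      /* relu */
        if ((act_flag & 1) && out[i][j] > 6.0f) out[i][j] = 6.0f;    /* relu6 */
      }
    }
  }
  if (write_mode == 13) {              /* NC4HW4: 4-channel blocks strided */
    for (int i = 0; i < 4; ++i)
      for (int j = 0; j < 4; ++j) dst[i * out_step + j] = out[i][j];
  } else {                             /* NHWC: 16 channels contiguous */
    for (int i = 0; i < 4; ++i)
      for (int j = 0; j < 4; ++j) dst[i * 4 + j] = out[i][j];
  }
}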
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv1x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + prfm pldl1keep, [x23] + mov x24, x23 + mov x25, x10 + subs x25, x25, #16 + blt LoopC12 + LoopC16: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x24], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v0.4s, v18.4s, v6.s[1] + fmla v1.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v0.4s, v18.4s, v6.s[3] + fmla v1.4s, v19.4s, v6.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[0] + fmla v1.4s, v17.4s, v7.s[0] + fmla v0.4s, v18.4s, v7.s[1] + fmla v1.4s, v19.4s, v7.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[2] + fmla v1.4s, v17.4s, v7.s[2] + fmla v0.4s, v18.4s, v7.s[3] + fmla v1.4s, v19.4s, v7.s[3] + subs x25, x25, #16 + bge LoopC16 + LoopC12: + adds x25, x25, #16 + cbz x25, LoopCEnd + cmp x25, #12 + blt LoopC8 + ld1 {v4.4s, v5.4s, v6.4s}, [x24], #48 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla 
v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v0.4s, v18.4s, v6.s[1] + fmla v1.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v0.4s, v18.4s, v6.s[3] + fmla v1.4s, v19.4s, v6.s[3] + sub x25, x25, #12 + b LoopCTail + LoopC8: + cmp x25, #8 + blt LoopC4 + ld1 {v4.4s, v5.4s}, [x24], #32 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v4.4s}, [x24], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + cmp x25, #2 + beq LoopC2 + cmp x25, #1 + beq LoopC1 + // LoopC3 + ld3r {v4.4s, v5.4s, v6.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + ld1 {v20.4s, v21.4s}, [x2], #32 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v0.4s, v18.4s, v5.4s + fmla v1.4s, v19.4s, v5.4s + fmla v0.4s, v20.4s, v6.4s + fmla v1.4s, v21.4s, v6.4s + b LoopCEnd + LoopC2: + ld1 {v4.2s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + b LoopCEnd + LoopC1: + ld1r {v4.4s}, [x24] + ld1 {v16.4s, v17.4s}, [x2], #32 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v4.2d, xzr // relu + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v4.4s + fmax v2.4s, v2.4s, v4.4s + fmax v3.4s, v3.4s, v4.4s + + ands x6, x6, #1 + beq WriteBack + movi v4.4s, #6 // relu6 + scvtf v4.4s, v4.4s + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v4.4s + fmin v2.4s, v2.4s, v4.4s + fmin v3.4s, v3.4s, v4.4s + fmin v4.4s, v4.4s, v4.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + b End + NC4HW4: + st1 {v0.4s}, [x0] + st1 {v1.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S new file mode 100644 
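SWConv1x8Kernel, like the other sliding-window kernels in this family, walks kernel_h x kernel_w taps, multiplies every input channel against an 8-wide block of output-channel weights, and skips kw_remainder weights of folded-out padding per kernel row. A scalar model under the assumption that all strides are in float elements rather than bytes; the function name is hypothetical:

#include <stddef.h>

/* One output pixel, 8 output channels; weights packed [kh][kw][ic][8]. */
void sw_conv_1x8_ref(float *dst, const float *src, const float *weight,
                     const float *bias, size_t kernel_h, size_t kernel_w,
                     size_t ic_algin, size_t in_kw_step, size_t in_kh_step,
                     size_t kw_remainder) {
  float acc[8];
  for (int oc = 0; oc < 8; ++oc) acc[oc] = bias ? bias[oc] : 0.0f;
  for (size_t kh = 0; kh < kernel_h; ++kh) {
    for (size_t kw = 0; kw < kernel_w; ++kw) {
      const float *in = src + kh * in_kh_step + kw * in_kw_step;
      for (size_t ic = 0; ic < ic_algin; ++ic) {
        for (int oc = 0; oc < 8; ++oc) acc[oc] += in[ic] * weight[oc];
        weight += 8;                   /* 8 weights consumed per input channel */
      }
    }
    weight += kw_remainder;            /* skip padding taps folded into filter */
  }
  for (int oc = 0; oc < 8; ++oc) dst[oc] = acc[oc];
}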
index 0000000000000000000000000000000000000000..221ebcf0ad92ffe453ab573e4cbed570b9cd9558 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S @@ -0,0 +1,407 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv2x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv2x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, 
v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla 
v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, 
x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + b End + NC4HW4: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..0d3be107b0228c675ba5fc8b7bd6f2503ddb4f81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S @@ -0,0 +1,265 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
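The NC4HW4 branch at the end of SWConv2x16Kernel splits the 2x16 tile into four C4 planes: x20/x21/x22 are derived from x0 at multiples of out_step, and each plane stores the tile's two pixels side by side (the st1 with #16 post-increment followed by a plain st1). A hedged C equivalent, where WriteNC4HW4_2x16 is a hypothetical name and acc[pixel][channel] stands for the v0..v7 accumulators (v0..v3 = pixel 0, v4..v7 = pixel 1):

#include <stddef.h>

void WriteNC4HW4_2x16(float *dst, const float acc[2][16], size_t out_step) {
  for (int plane = 0; plane < 4; ++plane) {    // 16 channels = 4 planes of C4
    float *p = dst + plane * out_step;         // x0, x20, x21, x22 in the asm
    for (int pixel = 0; pixel < 2; ++pixel) {  // both pixels adjacent in a plane
      for (int c = 0; c < 4; ++c) {
        p[pixel * 4 + c] = acc[pixel][plane * 4 + c];
      }
    }
  }
}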
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv2x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv2x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + subs x25, x25, #12 
+ bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..34706ef96a1aeacead8983586b7e926595bd297a --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S @@ -0,0 +1,533 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv3x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv3x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + stp x23, x24, [sp, #96] + stp x25, x26, [sp, #112] + + ldr x10, [sp, #128] + ldr x11, [sp, #136] + ldr x12, [sp, #144] + ldr x13, [sp, #152] + ldr x14, [sp, #160] + ldr x15, [sp, #168] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x19, x23, x13, lsl #1 + prfm pldl1keep, [x19] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x19], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, 
v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v10.4s, v30.4s, v24.s[0] + fmla v11.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v24.s[1] + fmla v9.4s, v29.4s, v24.s[1] + fmla v10.4s, v30.4s, v24.s[1] + fmla v11.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, 
v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v10.4s, v30.4s, v24.s[2] + fmla v11.4s, v31.4s, v24.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v24.s[3] + fmla v9.4s, v29.4s, v24.s[3] + fmla v10.4s, v30.4s, v24.s[3] + fmla v11.4s, v31.4s, v24.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x19], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, 
v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x19], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x19], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // 
relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..afbdecf517188c5465d715bd566008070dfac281 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S @@ -0,0 +1,332 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
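SWConv3x16Kernel is the first variant in this patch that needs 3 x 4 = 12 accumulator vectors, which is why its prologue spills the callee-saved v8..v11 with st1 before use. The payoff is weight reuse: each ld1 of v28..v31 (one input channel x 16 output channels) feeds twelve fmla instructions, one quadruple per output pixel, with the three pixels' inputs (x24/x26/x19) sitting in_sw_step floats apart. A hedged scalar sketch of that channel loop, with a hypothetical helper name and strides in float elements:

#include <stddef.h>

// acc[pixel][oc] corresponds to v0..v11, the weight row to v28..v31,
// and the three broadcast inputs to lanes of v16/v19/v22.
static void SWConv3x16ChannelLoop(float acc[3][16], const float *src,
                                  const float *weight, size_t ic_algin,
                                  size_t in_sw_step) {
  for (size_t c = 0; c < ic_algin; ++c) {
    const float *w = weight + c * 16;          // one ld1 {v28..v31}, 64 bytes
    for (int pixel = 0; pixel < 3; ++pixel) {  // x24, x26, x19 source pointers
      float in = src[pixel * in_sw_step + c];
      for (int oc = 0; oc < 16; ++oc) {
        acc[pixel][oc] += w[oc] * in;          // 12 fmla per weight load
      }
    }
  }
}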
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv3x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv3x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x19, x23, x13, lsl #1 + prfm pldl1keep, [x19] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x19], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] 
+ fmla v5.4s, v31.4s, v23.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v24.s[0] + fmla v5.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v24.s[1] + fmla v5.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v24.s[2] + fmla v5.4s, v29.4s, v24.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v24.s[3] + fmla v5.4s, v31.4s, v24.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x19], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x19], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla 
v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x19], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..9d5400a6b4359a69a13cae981f58c592ccdd6b50 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S @@ -0,0 +1,662 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
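Every kernel in this patch initializes its tile with the same InitWithBias/InitNoBias pattern: [x3] is reloaded once per pixel row without post-increment, so each output pixel starts from the identical per-channel bias, and dup vN.2d, xzr zeroes the accumulators when bias is NULL. A hedged C equivalent (InitAcc is a hypothetical name; rows is the tile's pixel count and oc_block its output-channel width):

static void InitAcc(float *acc, const float *bias, int rows, int oc_block) {
  for (int r = 0; r < rows; ++r) {
    for (int oc = 0; oc < oc_block; ++oc) {
      acc[r * oc_block + oc] = (bias != NULL) ? bias[oc] : 0.0f;  // ld1 from [x3], no increment
    }
  }
}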
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv4x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x27], #48 + ld1 {v25.4s, v26.4s, v27.4s}, [x28], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + 
fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + fmla v12.4s, v28.4s, v26.s[0] + fmla v13.4s, v29.4s, v26.s[0] + fmla v14.4s, v30.4s, v26.s[0] + fmla v15.4s, v31.4s, v26.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + fmla v12.4s, v28.4s, v26.s[1] + fmla v13.4s, v29.4s, v26.s[1] + fmla v14.4s, v30.4s, v26.s[1] + fmla v15.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + fmla v12.4s, v28.4s, v26.s[2] + fmla v13.4s, v29.4s, v26.s[2] + fmla v14.4s, v30.4s, v26.s[2] + fmla v15.4s, v31.4s, v26.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + fmla v12.4s, v28.4s, v26.s[3] + fmla v13.4s, v29.4s, v26.s[3] + fmla v14.4s, v30.4s, v26.s[3] + fmla v15.4s, v31.4s, v26.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, 
v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v10.4s, v30.4s, v24.s[0] + fmla v11.4s, v31.4s, v24.s[0] + fmla v12.4s, v28.4s, v27.s[0] + fmla v13.4s, v29.4s, v27.s[0] + fmla v14.4s, v30.4s, v27.s[0] + fmla v15.4s, v31.4s, v27.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v24.s[1] + fmla v9.4s, v29.4s, v24.s[1] + fmla v10.4s, v30.4s, v24.s[1] + fmla v11.4s, v31.4s, v24.s[1] + fmla v12.4s, v28.4s, v27.s[1] + fmla v13.4s, v29.4s, v27.s[1] + fmla v14.4s, v30.4s, v27.s[1] + fmla v15.4s, v31.4s, v27.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v10.4s, v30.4s, v24.s[2] + fmla v11.4s, v31.4s, v24.s[2] + fmla v12.4s, v28.4s, v27.s[2] + fmla v13.4s, v29.4s, v27.s[2] + fmla v14.4s, v30.4s, v27.s[2] + fmla v15.4s, v31.4s, v27.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v24.s[3] + fmla v9.4s, v29.4s, v24.s[3] + fmla v10.4s, v30.4s, v24.s[3] + fmla v11.4s, v31.4s, v24.s[3] + fmla v12.4s, v28.4s, v27.s[3] + fmla v13.4s, v29.4s, v27.s[3] + fmla v14.4s, v30.4s, v27.s[3] + fmla v15.4s, v31.4s, v27.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x27], #32 + ld1 {v25.4s, v26.4s}, [x28], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, 
v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + fmla v12.4s, v28.4s, v26.s[0] + fmla v13.4s, v29.4s, v26.s[0] + fmla v14.4s, v30.4s, v26.s[0] + fmla v15.4s, v31.4s, v26.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + fmla v12.4s, v28.4s, v26.s[1] + fmla v13.4s, v29.4s, v26.s[1] + fmla v14.4s, v30.4s, v26.s[1] + fmla v15.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + fmla v12.4s, v28.4s, v26.s[2] + fmla v13.4s, v29.4s, v26.s[2] + fmla v14.4s, v30.4s, v26.s[2] + fmla v15.4s, v31.4s, v26.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + fmla v12.4s, v28.4s, v26.s[3] + fmla v13.4s, v29.4s, v26.s[3] + fmla v14.4s, v30.4s, v26.s[3] + fmla v15.4s, v31.4s, v26.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v25.4s}, [x28], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla 
v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v25.s}[0], [x28], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, 
v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + fmin v12.4s, v12.4s, v24.4s + fmin v13.4s, v13.4s, v24.4s + fmin v14.4s, v14.4s, v24.4s + fmin v15.4s, v15.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x22] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0], #16 + st1 {v12.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20], #16 + st1 {v13.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21], #16 + st1 {v14.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22], #16 + st1 {v15.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..0de222ce4f04c39638774b9cc45228f6988bb3b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S @@ -0,0 +1,406 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
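All the SWConv<rows>x<cols>Kernel files in this patch share one register-blocking scheme: a fixed tile of output pixels (rows) by output channels (columns) is held in NEON accumulators while the loops walk kernel_h, kernel_w, and the aligned input-channel count. Below is a minimal scalar C sketch of the 4x16 tile computed above; it assumes element-count strides (the prologue scales them to bytes with lsl #2) and omits the C12/C8/C4 ic tail splitting, the kw_remainder weight skip, and the NC4HW4 write path.

#include <math.h>
#include <stddef.h>

enum { TILE_ROWS = 4, TILE_COLS = 16 };  /* the 4x16 variant */

static void SWConvTileRef(float *dst, const float *src, const float *weight,
                          const float *bias, size_t kernel_h, size_t kernel_w,
                          size_t act_flag, size_t out_step, size_t ic_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step) {
  float acc[TILE_ROWS][TILE_COLS];
  for (int r = 0; r < TILE_ROWS; r++)
    for (int c = 0; c < TILE_COLS; c++) acc[r][c] = (bias != NULL) ? bias[c] : 0.0f;
  const float *w = weight;
  for (size_t kh = 0; kh < kernel_h; kh++) {
    for (size_t kw = 0; kw < kernel_w; kw++) {
      const float *in = src + kh * in_kh_step + kw * in_kw_step;
      for (size_t ic = 0; ic < ic_algin; ic++, w += TILE_COLS) {
        for (int r = 0; r < TILE_ROWS; r++)
          for (int c = 0; c < TILE_COLS; c++)
            acc[r][c] += in[r * in_sw_step + ic] * w[c];  /* fmla vAcc.4s, vW.4s, vIn.s[lane] */
      }
    }
  }
  for (int r = 0; r < TILE_ROWS; r++) {
    for (int c = 0; c < TILE_COLS; c++) {
      float v = acc[r][c];
      if (act_flag & 3) v = fmaxf(v, 0.0f);  /* relu */
      if (act_flag & 1) v = fminf(v, 6.0f);  /* relu6 adds the upper clamp */
      dst[r * out_step + c] = v;             /* NHWC write-back */
    }
  }
}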
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv4x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + ld1 {v6.4s, v7.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x27], #48 + ld1 {v25.4s, v26.4s, v27.4s}, [x28], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v25.s[1] + fmla v7.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v25.s[2] + fmla v7.4s, v29.4s, v25.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v25.s[3] + fmla v7.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v6.4s, v28.4s, v26.s[0] + 
fmla v7.4s, v29.4s, v26.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + fmla v6.4s, v30.4s, v26.s[1] + fmla v7.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v6.4s, v28.4s, v26.s[2] + fmla v7.4s, v29.4s, v26.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + fmla v6.4s, v30.4s, v26.s[3] + fmla v7.4s, v31.4s, v26.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v24.s[0] + fmla v5.4s, v29.4s, v24.s[0] + fmla v6.4s, v28.4s, v27.s[0] + fmla v7.4s, v29.4s, v27.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v24.s[1] + fmla v5.4s, v31.4s, v24.s[1] + fmla v6.4s, v30.4s, v27.s[1] + fmla v7.4s, v31.4s, v27.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v24.s[2] + fmla v5.4s, v29.4s, v24.s[2] + fmla v6.4s, v28.4s, v27.s[2] + fmla v7.4s, v29.4s, v27.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v24.s[3] + fmla v5.4s, v31.4s, v24.s[3] + fmla v6.4s, v30.4s, v27.s[3] + fmla v7.4s, v31.4s, v27.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x27], #32 + ld1 {v25.4s, v26.4s}, [x28], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v25.s[1] + fmla v7.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v25.s[2] + fmla v7.4s, v29.4s, v25.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v25.s[3] + fmla v7.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v6.4s, v28.4s, v26.s[0] + fmla v7.4s, 
v29.4s, v26.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + fmla v6.4s, v30.4s, v26.s[1] + fmla v7.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v6.4s, v28.4s, v26.s[2] + fmla v7.4s, v29.4s, v26.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + fmla v6.4s, v30.4s, v26.s[3] + fmla v7.4s, v31.4s, v26.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v25.4s}, [x28], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v25.s[1] + fmla v7.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v25.s[2] + fmla v7.4s, v29.4s, v25.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v25.s[3] + fmla v7.4s, v31.4s, v25.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v25.s}[0], [x28], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + add x22, x21, x7 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + st1 {v6.4s, v7.4s}, [x22] + b End + NC4HW4: + st1 {v0.4s}, 
[x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v6.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v7.4s}, [x20] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..11583d539d32c06a7a12cd1302e64cb6afb6c594 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S @@ -0,0 +1,457 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv5x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv5x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm 
pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + add x20, x23, x13, lsl #2 + prfm pldl1keep, [x20] + subs x25, x25, #4 + blt LoopCTail + LoopC4: + ld1 {v20.4s}, [x24], #16 + ld1 {v21.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v23.4s}, [x28], #16 + ld1 {v24.4s}, [x20], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v30.4s, v20.s[2] + fmla v3.4s, v31.4s, v20.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v23.s[2] + fmla v13.4s, v29.4s, v23.s[2] + fmla v14.4s, v30.4s, v23.s[2] + fmla v15.4s, v31.4s, v23.s[2] + fmla v16.4s, v28.4s, v24.s[2] + fmla v17.4s, v29.4s, v24.s[2] + fmla v18.4s, v30.4s, v24.s[2] + fmla v19.4s, v31.4s, v24.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[3] + fmla v1.4s, v29.4s, v20.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v23.s[3] + fmla v13.4s, v29.4s, v23.s[3] + fmla v14.4s, v30.4s, v23.s[3] + fmla v15.4s, v31.4s, v23.s[3] + fmla v16.4s, v28.4s, v24.s[3] + fmla v17.4s, v29.4s, v24.s[3] + fmla v18.4s, v30.4s, v24.s[3] + fmla v19.4s, v31.4s, v24.s[3] + subs x25, x25, #4 + bge LoopC4 + LoopCTail: + add x25, x25, #4 + cbz x25, LoopCEnd + cmp x25, #3 + beq LoopCTail3 + cmp x25, #2 + beq LoopCTail2 + ld1 {v20.s}[0], [x24], #4 + ld1 {v21.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v23.s}[0], [x28], #4 + ld1 {v24.s}[0], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, 
v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + b LoopCEnd + LoopCTail2: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + b LoopCEnd + LoopCTail3: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v20.s}[2], [x24], #4 + ld1 {v21.s}[2], [x26], #4 + ld1 {v22.s}[2], [x27], #4 + ld1 {v23.s}[2], [x28], #4 + ld1 {v24.s}[2], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, 
v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v30.4s, v20.s[2] + fmla v3.4s, v31.4s, v20.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v23.s[2] + fmla v13.4s, v29.4s, v23.s[2] + fmla v14.4s, v30.4s, v23.s[2] + fmla v15.4s, v31.4s, v23.s[2] + fmla v16.4s, v28.4s, v24.s[2] + fmla v17.4s, v29.4s, v24.s[2] + fmla v18.4s, v30.4s, v24.s[2] + fmla v19.4s, v31.4s, v24.s[2] + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + fmax v16.4s, v16.4s, v24.4s + fmax v17.4s, v17.4s, v24.4s + fmax v18.4s, v18.4s, v24.4s + fmax v19.4s, v19.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + fmin v12.4s, v12.4s, v24.4s + fmin v13.4s, v13.4s, v24.4s + fmin v14.4s, v14.4s, v24.4s + fmin v15.4s, v15.4s, v24.4s + fmin v16.4s, v16.4s, v24.4s + fmin v17.4s, v17.4s, v24.4s + fmin v18.4s, v18.4s, v24.4s + fmin v19.4s, v19.4s, v24.4s + + WriteBack: + add x20, x0, x7 + add x21, x0, x7, LSL #1 + add x23, x0, x7, LSL #2 + add x22, x20, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x22] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x23] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0], #16 + st1 {v12.4s}, [x0], #16 + st1 {v16.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20], #16 + st1 {v13.4s}, [x20], #16 + st1 {v17.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21], #16 + st1 {v14.4s}, [x21], #16 + st1 {v18.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22], #16 + st1 {v15.4s}, [x22], #16 + st1 {v19.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..58181f0fc586cb49bff5a428aa8f9c3f6f5648a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S @@ -0,0 +1,308 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv5x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv5x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + ld1 {v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + add x20, x23, x13, lsl #2 + prfm pldl1keep, [x20] + subs x25, x25, #4 + blt LoopCTail + LoopC4: + ld1 {v20.4s}, [x24], #16 + ld1 {v21.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v23.4s}, [x28], #16 + ld1 {v24.4s}, [x20], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, 
v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v23.s[2] + fmla v7.4s, v29.4s, v23.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v0.4s, v30.4s, v20.s[3] + fmla v1.4s, v31.4s, v20.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v23.s[3] + fmla v7.4s, v31.4s, v23.s[3] + fmla v8.4s, v30.4s, v24.s[3] + fmla v9.4s, v31.4s, v24.s[3] + subs x25, x25, #4 + bge LoopC4 + LoopCTail: + add x25, x25, #4 + cbz x25, LoopCEnd + cmp x25, #3 + beq LoopCTail3 + cmp x25, #2 + beq LoopCTail2 + ld1 {v20.s}[0], [x24], #4 + ld1 {v21.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v23.s}[0], [x28], #4 + ld1 {v24.s}[0], [x20], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + b LoopCEnd + LoopCTail2: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + b LoopCEnd + LoopCTail3: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v20.s}[2], [x24], #4 + ld1 {v21.s}[2], [x26], #4 + ld1 {v22.s}[2], [x27], #4 + ld1 {v23.s}[2], [x28], #4 + ld1 {v24.s}[2], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + ld1 {v26.4s, v27.4s}, [x2], #32 + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + fmla v0.4s, v26.4s, v20.s[2] + fmla v1.4s, v27.4s, v20.s[2] + fmla v2.4s, v26.4s, v21.s[2] + fmla v3.4s, v27.4s, v21.s[2] + fmla 
v4.4s, v26.4s, v22.s[2] + fmla v5.4s, v27.4s, v22.s[2] + fmla v6.4s, v26.4s, v23.s[2] + fmla v7.4s, v27.4s, v23.s[2] + fmla v8.4s, v26.4s, v24.s[2] + fmla v9.4s, v27.4s, v24.s[2] + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + + WriteBack: + add x20, x0, x7 + cmp x15, #13 + beq NC4HW4 + add x21, x0, x7, LSL #1 + add x23, x0, x7, LSL #2 + add x22, x20, x7, LSL #1 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + st1 {v6.4s, v7.4s}, [x22] + st1 {v8.4s, v9.4s}, [x23] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v6.4s}, [x0], #16 + st1 {v8.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v7.4s}, [x20], #16 + st1 {v9.4s}, [x20] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S new file mode 100644 index 0000000000000000000000000000000000000000..74723e98816adbe4d6ca7e4451edf4296f29c5d3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
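ConvSW5x8Kernel.S completes the sliding-window family in this patch, and every variant ends with the same dispatch: the cmp x15, #13 selects the NC4HW4 write-back, anything else takes the NHWC path. The sketch below illustrates the two layouts; 13 is assumed to be the caller's NC4HW4 enum value, and acc stands for a kernel's accumulator tile.

#include <stddef.h>
#include <string.h>

/* Sketch only: write a rows x cols accumulator tile in the two layouts
 * used by the WriteBack/NC4HW4 blocks above. out_step is in floats. */
static void WriteTileRef(float *dst, const float *acc, int rows, int cols,
                         size_t out_step, size_t write_mode) {
  if (write_mode == 13) {  /* NC4HW4: one plane per group of 4 channels */
    for (int c4 = 0; c4 < cols / 4; c4++)
      for (int r = 0; r < rows; r++)
        memcpy(dst + c4 * out_step + r * 4, acc + r * cols + c4 * 4,
               4 * sizeof(float));
  } else {                 /* NHWC: output rows are out_step floats apart */
    for (int r = 0; r < rows; r++)
      memcpy(dst + r * out_step, acc + r * cols, cols * sizeof(float));
  }
}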
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w) + +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: in_kh_step, x6: in_kw_step, x7: kernel_w +asm_function DeconvDwFp32Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x3, #0 + beq End + cmp x4, #0 + beq End + ld1 {v1.4s}, [x1] + + mov x13, x0 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x4 + LoopW: + ld1 {v0.4s}, [x15] + ld1 {v2.4s}, [x16], #16 + fmla v0.4s, v1.4s, v2.4s + st1 {v0.4s}, [x15], x6 + subs x17, x17, #1 + bne LoopW + subs x3, x3, #1 + add x13, x13, x5 + add x14, x14, x7 + bne LoopH + End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..1ef311ee4a3eb4e1a0664ee6eb12de7041de770e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
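DeconvDwFp32Border scatters a single four-channel input pixel into the output window it influences. A scalar mirror of the loop nest above (not part of the patch): the assembly adds every step argument to a pointer unscaled, so they are byte strides here, and kernel_w acts as the byte stride between weight rows (add x14, x14, x7).

#include <stddef.h>

/* Hypothetical reference for DeconvDwFp32Border. */
static void DeconvDwFp32BorderRef(float *dst, const float *src, const float *weight,
                                  size_t height, size_t width, size_t in_kh_step,
                                  size_t in_kw_step, size_t kernel_w) {
  char *dst_kh = (char *)dst;
  const char *weight_kh = (const char *)weight;
  for (size_t kh = 0; kh < height; kh++) {
    char *dst_kw = dst_kh;
    const float *w = (const float *)weight_kh;
    for (size_t kw = 0; kw < width; kw++) {
      float *out = (float *)dst_kw;
      for (int c = 0; c < 4; c++) out[c] += src[c] * w[c];  /* fmla v0.4s, v1.4s, v2.4s */
      w += 4;                 /* ld1 {v2.4s}, [x16], #16 */
      dst_kw += in_kw_step;   /* st1 {v0.4s}, [x15], x6 */
    }
    dst_kh += in_kh_step;     /* add x13, x13, x5 */
    weight_kh += kernel_w;    /* add x14, x14, x7 */
  }
}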
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x22, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.4s}, [x16], x8 + LoopKh: + mov x21, x22 + mov x13, x6 + LoopKw: + ld1 {v0.4s}, [x21] + ld1 {v2.4s}, [x19], #16 + fmla v0.4s, v1.4s, v2.4s + st1 {v0.4s}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x22, x22, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..299d37008d744b11d07a3caa845238f2d39f3c34 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
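DeconvDwFp32Center applies the same scatter to a dense height x width grid of input pixels. A rough scalar mirror of the loop nest above (again treating the step arguments as byte strides, since they come off the stack and are added to pointers unscaled):

#include <stddef.h>

/* Hypothetical reference for DeconvDwFp32Center; one C4 channel block. */
static void DeconvDwFp32CenterRef(float *dst, const float *src, const float *weight,
                                  size_t height, size_t width, size_t kernel_h,
                                  size_t kernel_w, size_t out_h_step,
                                  size_t block_channel, size_t in_sh_step,
                                  size_t in_sw_step, size_t in_kh_step,
                                  size_t in_kw_step) {
  for (size_t oh = 0; oh < height; oh++) {
    char *dst_w = (char *)dst;
    const char *src_w = (const char *)src;
    for (size_t ow = 0; ow < width; ow++) {
      const float *s = (const float *)src_w;  /* ld1 {v1.4s}, [x16], x8 */
      const float *w = weight;                /* weight restarts per output pixel */
      char *dst_kh = dst_w;
      for (size_t kh = 0; kh < kernel_h; kh++) {
        char *dst_kw = dst_kh;
        for (size_t kw = 0; kw < kernel_w; kw++) {
          float *d = (float *)dst_kw;
          for (int c = 0; c < 4; c++) d[c] += s[c] * w[c];  /* fmla */
          w += 4;
          dst_kw += in_kw_step;
        }
        dst_kh += in_kh_step;
      }
      src_w += block_channel;
      dst_w += in_sw_step;
    }
    dst = (float *)((char *)dst + in_sh_step);
    src = (const float *)((const char *)src + out_h_step);
  }
}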
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwInt8Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x18, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.4h}, [x16], x8 + LoopKh: + mov x21, x18 + mov x13, x6 + LoopKw: + ld1 {v0.4s}, [x21] + ld1 {v2.4h}, [x19], #8 + smlal v0.4s, v1.4h, v2.4h + st1 {v0.4s}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x18, x18, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S new file mode 100644 index 0000000000000000000000000000000000000000..ff9d6a648d8f98dbff2f0ee8b468def9ff3c1310 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
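DeconvDwInt8Center repeats the fp32 loop nest above with inputs and weights pre-widened to int16 and 32-bit accumulators; the smlal in its LoopKw is, per four-channel block, the following (a sketch, not the patch's code):

#include <stdint.h>

/* One LoopKw tap of DeconvDwInt8Center. */
static inline void DeconvDwInt8TapRef(int32_t dst[4], const int16_t src[4],
                                      const int16_t weight[4]) {
  for (int c = 0; c < 4; c++) {
    dst[c] += (int32_t)src[c] * (int32_t)weight[c];  /* smlal v0.4s, v1.4h, v2.4h */
  }
}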
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, +// int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, +// int32_t acc_max) +// x0: dst, x1: output_buffer, x2: bias, x3: block_channel, x4: pixel_nums, x5: out_multiplier +// x6: left_shift, x7: right_shift, x8: out_zp, x9: acc_min, x10: acc_max + +asm_function DeconvDwInt8Post + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ld1 {v25.4s}, [x2] + + dup v26.4s, w6 // left_shift + dup v27.4s, w5 // out_multiplier + dup v28.4s, w7 // right_shift + + ldr w8, [sp] + dup v29.4s, w8 // out_zp + ldr w9, [sp, #8] + dup v30.4s, w9 // acc_min + ldr w10, [sp, #16] + dup v31.4s, w10 // acc_max + + LoopCount: + ld1 {v0.4s}, [x1], #16 + add v0.4s, v0.4s, v25.4s + sqshl v0.4s, v0.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrshl v0.4s, v0.4s, v28.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], x3 + + sub x4, x4, #1 + cmp x4, #1 + bge LoopCount + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S new file mode 100644 index 0000000000000000000000000000000000000000..ef3a39fb50a2e71151b23b84c9a7009bacae5ad0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S @@ -0,0 +1,48 @@ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +// void DynamicGatherArm64(const int8_t *src, float *output, int count_16, int zp, float scale); +// x0: src +// x1: output +// w2: count_16 +// w3: zp +// v0.s[0]: scale + +asm_function DynamicGatherArm64 + mov x5, x0 // reload src + mov x6, x1 // reload out + mov w7, w2 // reload count_16 + dup v1.4s, w3 // zp + dup v2.4s, v0.s[0] // scale + + LoopCount: + ld1 {v0.16b}, [x5], #16 + + sxtl v3.8h, v0.8b + sxtl2 v4.8h, v0.16b + + sxtl v16.4s, v3.4h + sxtl2 v17.4s, v3.8h + sxtl v18.4s, v4.4h + sxtl2 v19.4s, v4.8h + + sub v16.4s, v16.4s, v1.4s + scvtf v16.4s,v16.4s + fmul v16.4s, v16.4s, v2.4s + sub v17.4s, v17.4s, v1.4s + scvtf v17.4s,v17.4s + fmul v17.4s, v17.4s, v2.4s + sub v18.4s, v18.4s, v1.4s + scvtf v18.4s,v18.4s + fmul v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v1.4s + scvtf v19.4s,v19.4s + fmul v19.4s, v19.4s, v2.4s + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], #64 + subs w7, w7, #16 + bgt LoopCount +ret + +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S new file mode 100644 index 0000000000000000000000000000000000000000..ba60f39021c876a0dde568e920c4d4d4820ecb27 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,233 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset); +// x0: output, x1: input, x2: weight, x3: ksize, x4: ic8, x5: oc4, x6: offset +asm_function IndirectGemmInt16to32_8x4 + + .macro INIT_ZERO + dup v28.4s, wzr + mov v29.16b, v28.16b + mov v30.16b, v28.16b + mov v31.16b, v28.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + LoopOc: + mov x7, x3 + mov x8, x1 + + LoopKsize: + mov x9, x0 + INIT_ZERO + + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + // load weight + ld1 {v16.8h}, [x2], #16 + smull v24.4s, v16.4h, v0.h[0] + smull v25.4s, v16.4h, v1.h[0] + // load weight + ld1 {v17.8h}, [x2], #16 + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smull v26.4s, v16.4h, v2.h[0] + smull v27.4s, v16.4h, v3.h[0] + + subs x10, x4, #1 + beq LoopIcEnd + + LoopIc: + + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + // load weight + ld1 {v16.8h, v17.8h}, [x2], #32 + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + // load input 
+ ld1 {v0.8h, v1.8h}, [x8], #32 + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v16.4h, v0.h[0] + smlal v25.4s, v16.4h, v1.h[0] + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + smlal v26.4s, v16.4h, v2.h[0] + smlal v27.4s, v16.4h, v3.h[0] + + subs x10, x10, #1 + bne LoopIc + + LoopIcEnd: + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + st1 {v24.4s}, [x9], x6 + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + st1 {v25.4s}, [x9], x6 + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + st1 {v26.4s}, [x9], x6 + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + st1 {v27.4s}, [x9], x6 + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + st1 {v28.4s}, [x9], x6 + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + st1 {v29.4s}, [x9], x6 + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + + st1 {v30.4s}, [x9], x6 + st1 {v31.4s}, [x9] + + subs x7, x7, #1 + add x0, x0, #16 + bne LoopKsize + + subs x5, x5, #1 + bne LoopOc + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..bd427335ea6f6aed0e2221dbf7232a2857315bbb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S @@ -0,0 +1,252 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_default_function MatVecMulFp32 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + + mov w14, #4 // sizeof(float) + mul w8, w14, w5 // rhs depthx1 block stride + mov w14, #4 + mul w13, w8, w14 // rhs depthx4 block stride + +Loop: + mov x15, x0 // reload a ptr + mov x7, x1 // reload b ptr + mov w9, w5 // reload depth + cmp w6, #4 + blt Loop1x1 + +Loop1x4: + dup v10.8h, wzr + dup v11.8h, wzr + dup v12.8h, wzr + dup v13.8h, wzr + dup v14.8h, wzr + + add x10, x7, x8 + add x11, x10, x8 + add x12, x11, x8 + +Depth8_1x4: + cmp w9, #8 + blt Depth4_1x4 + sub w9, w9, #8 + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v4.4s, v5.4s}, [x10], #32 + cmp w9, #8 + blt Depth8_1x4_Loop_End + +Depth8_1x4_Loop: + fmla v10.4s, v0.4s, v2.4s + fmla v10.4s, v1.4s, v3.4s + ld1 {v6.4s, v7.4s}, [x11], #32 + fmla v11.4s, v0.4s, v4.4s + fmla v11.4s, v1.4s, v5.4s + ld1 {v8.4s, v9.4s}, [x12], #32 + fmla v12.4s, v0.4s, v6.4s + fmla v12.4s, v1.4s, v7.4s + ld1 {v2.4s, v3.4s}, [x7], #32 + fmla v13.4s, v0.4s, v8.4s + fmla v13.4s, v1.4s, v9.4s + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v4.4s, v5.4s}, [x10], #32 + sub w9, w9, #8 + cmp w9, #8 + bge Depth8_1x4_Loop + +Depth8_1x4_Loop_End: + fmla v10.4s, v0.4s, v2.4s + fmla v10.4s, v1.4s, v3.4s + ld1 {v6.4s, v7.4s}, [x11], #32 + fmla v11.4s, v0.4s, v4.4s + fmla v11.4s, v1.4s, v5.4s + ld1 {v8.4s, v9.4s}, [x12], #32 + fmla v12.4s, v0.4s, v6.4s + fmla v12.4s, v1.4s, v7.4s + fmla v13.4s, v0.4s, v8.4s + fmla v13.4s, v1.4s, v9.4s + +Depth4_1x4: + cmp w9, #4 + blt Depth1_1x4 + sub w9, w9, #4 + ld1 {v0.4s}, [x15], #16 + ld1 {v1.4s}, [x7], #16 + ld1 {v2.4s}, [x10], #16 + cmp w9, #4 + blt Depth4_1x4_Loop_End + +Depth4_1x4_Loop: + fmla v10.4s, v1.4s, v0.4s + ld1 {v3.4s}, [x11], #16 + fmla v11.4s, v2.4s, v0.4s + ld1 {v4.4s}, [x12], #16 + fmla v12.4s, v3.4s, v0.4s + ld1 {v1.4s}, [x7], #16 + fmla v13.4s, v4.4s, v0.4s + ld1 {v0.4s}, [x15], #16 + ld1 {v2.4s}, [x10], #16 + sub w9, w9, #4 + cmp w9, #4 + bge Depth4_1x4_Loop + +Depth4_1x4_Loop_End: + fmla v10.4s, v1.4s, v0.4s + ld1 {v3.4s}, [x11], #16 + fmla v11.4s, v2.4s, v0.4s + ld1 {v4.4s}, [x12], #16 + fmla v12.4s, v3.4s, v0.4s + fmla v13.4s, v4.4s, v0.4s + +Depth1_1x4: + cmp w9, #0 + beq End1x4 + ld1 {v0.s}[0], [x15], #4 + ld1 {v1.s}[0], [x7], #4 + ld1 {v1.s}[1], [x10], #4 + ld1 {v1.s}[2], [x11], #4 + ld1 {v1.s}[3], [x12], #4 + + fmla v14.4s, v1.4s, v0.s[0] + sub w9, w9, #1 + cbz w9, End1x4 + b Depth1_1x4 + +End1x4: + faddp v15.4s, v10.4s, v11.4s + faddp v16.4s, v12.4s, v13.4s + faddp v17.4s, v15.4s, v16.4s + fadd v14.4s, v14.4s, v17.4s + + cbz x3, Act1x4 + ld1 {v15.4s}, [x3], #16 + fadd v14.4s, v14.4s, v15.4s // add bias + +Act1x4: + cmp w4, #3 + beq Relu6_1x4 + cmp w4, #1 + beq Relu1x4 + b Write1x4 + +Relu6_1x4: + movi v15.4s, #0x46, lsl #8 + 
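+    // v15 is loaded as the Relu6 upper bound; after the fmin below, control
+    // falls through into Relu1x4, which applies the max(x, 0) half of the clamp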
fmin v14.4s, v14.4s, v15.4s + +Relu1x4: + dup v15.4s, wzr + fmax v14.4s, v14.4s, v15.4s + +Write1x4: + st1 {v14.4s}, [x2], #16 + sub w6, w6, #4 + cbz w6, End + add x1, x1, x13 + b Loop + + +Loop1x1: + dup v4.4s, wzr + dup v5.4s, wzr + +Depth8_1x1: + cmp w9, #8 + blt Depth4_1x1 + + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + + fmla v4.4s, v2.4s, v0.4s + fmla v4.4s, v3.4s, v1.4s + sub w9, w9, #8 + cbz w9, End1x1 + b Depth8_1x1 + +Depth4_1x1: + cmp w9, #4 + blt Depth1_1x1 + + ld1 {v0.4s}, [x15], #16 + ld1 {v1.4s}, [x7], #16 + + fmla v4.4s, v1.4s, v0.4s + sub w9, w9, #4 + cbz w9, End1x1 + b Depth8_1x1 + +Depth1_1x1: + ld1 {v0.s}[0], [x15], #4 + ld1 {v1.s}[0], [x7], #4 + + fmla v5.4s, v1.4s, v0.s[0] + sub w9, w9, #1 + cbz w9, End1x1 + b Depth1_1x1 + +End1x1: + faddp v6.4s, v4.4s, v4.4s + faddp v7.4s, v6.4s, v6.4s + fadd v7.4s, v7.4s, v5.4s + + cbz x3, Act1x1 + ld1 {v8.s}[0], [x3], #4 + fadd v7.4s, v7.4s, v8.4s // add bias + +Act1x1: + cmp w4, #3 + beq Relu6_1x1 + cmp w4, #1 + beq Relu1x1 + b Write1x1 + +Relu6_1x1: + movi v8.4s, #0x46, lsl #8 + fmin v7.4s, v7.4s, v8.4s + +Relu1x1: + dup v8.4s, wzr + fmax v7.4s, v7.4s, v8.4s + +Write1x1: + st1 {v7.s}[0], [x2], #4 + sub w6, w6, #1 + cbz w6, End + add x1, x1, x8 + b Loop + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..058a807ce065db20dd0049b30d21c9a12ecc053c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S @@ -0,0 +1,198 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_default_function MatVecMulPackFp32 + sub sp, sp, #16 + stp x29, x30, [sp] + + dup v1.2d, xzr + mov w7, #6 + dup v2.4s, w7 + scvtf v2.4s, v2.4s + subs w6, w6, #8 + blt Loop1xNStart + Loop1x8Start: + bl Compute1x8Unit + st1 {v24.4s, v25.4s}, [x2], #32 + subs w6, w6, #8 + bge Loop1x8Start + + Loop1xNStart: + add w6, w6, #8 + cbz w6, End + subs w6, w6, #4 + ble Loop1x4Start + bl Compute1x8Unit + st1 {v24.4s}, [x2], #16 + st1 {v25.s}[0], [x2], #4 + cmp w6, #1 + beq End + st1 {v25.s}[1], [x2], #4 + cmp w6, #2 + beq End + st1 {v25.s}[2], [x2] + b End + + Loop1x4Start: + add w6, w6, #4 + cbz w6, End + bl Compute1x4Unit + st1 {v24.s}[0], [x2], #4 + cmp w6, #1 + beq End + st1 {v24.s}[1], [x2], #4 + cmp w6, #2 + beq End + st1 {v24.s}[2], [x2], #4 + cmp w6, #3 + beq End + st1 {v24.s}[3], [x2], #4 + b End + + Compute1x8Unit: + mov x7, x0 // reload a-ptr + mov w8, w5 // reset depth + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + cbz x3, Compute1x8Enter + ld1 {v24.4s, v25.4s}, [x3], #32 + Compute1x8Enter: + subs w8, w8, #4 + blt Compute1x8Tail + Compute1x8: + ld1 {v0.4s}, [x7], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + fmla v24.4s, v16.4s, v0.s[0] + fmla v25.4s, v17.4s, v0.s[0] + fmla v26.4s, v18.4s, v0.s[1] + fmla v27.4s, v19.4s, v0.s[1] + fmla v28.4s, v20.4s, v0.s[2] + fmla v29.4s, v21.4s, v0.s[2] + fmla v30.4s, v22.4s, v0.s[3] + fmla v31.4s, v23.4s, v0.s[3] + subs w8, w8, #4 + bge Compute1x8 + Compute1x8Tail: + add w8, w8, #4 + cbz w8, Compute1x8UnionTail + Compute1x8DepthTail: + ld1 {v0.s}[0], [x7], #4 + ld1 {v16.4s, v17.4s}, [x1], #32 + fmla v24.4s, v16.4s, v0.s[0] + fmla v25.4s, v17.4s, v0.s[0] + subs w8, w8, #1 + bgt Compute1x8DepthTail + Compute1x8UnionTail: + fadd v24.4s, v24.4s, v26.4s + fadd v25.4s, v25.4s, v27.4s + fadd v28.4s, v28.4s, v30.4s + fadd v29.4s, v29.4s, v31.4s + fadd v24.4s, v24.4s, v28.4s + fadd v25.4s, v25.4s, v29.4s + Act1x8: + cmp x4, #3 + beq Relu61x8 + cmp x4, #1 + beq Relu1x8 + b Return1x8 + Relu61x8: + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmax v24.4s, v24.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + b Return1x8 + Relu1x8: + fmax v24.4s, v24.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + Return1x8: + ret + + Compute1x4Unit: + mov x7, x0 // reload a-ptr + mov w8, w5 // reset depth + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + cbz x3, Compute1x4Enter + ld1 {v24.4s}, [x3] + Compute1x4Enter: + subs w8, w8, #4 + blt Compute1x4Tail + Compute1x4: + ld1 {v0.4s}, [x7], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + fmla v24.4s, v16.4s, v0.s[0] + fmla v26.4s, v18.4s, v0.s[1] + fmla v28.4s, v20.4s, v0.s[2] + fmla v30.4s, v22.4s, v0.s[3] + subs w8, w8, #4 + bge Compute1x4 + Compute1x4Tail: + add w8, w8, #4 + cbz w8, Compute1x4UnionTail + Compute1x4DepthTail: + ld1 {v0.s}[0], [x7], #4 + ld1 {v16.4s}, [x1] + add x1, x1, #32 + fmla v24.4s, v16.4s, v0.s[0] + subs w8, w8, #1 + bgt Compute1x4DepthTail + Compute1x4UnionTail: + fadd v24.4s, v24.4s, v26.4s + fadd v28.4s, v28.4s, v30.4s + fadd v24.4s, v24.4s, v28.4s + Act1x4: + cmp x4, #3 + beq 
Relu61x4 + cmp x4, #1 + beq Relu1x4 + b Return1x4 + Relu61x4: + fmin v24.4s, v24.4s, v2.4s + fmax v24.4s, v24.4s, v1.4s + b Return1x8 + Relu1x4: + fmax v24.4s, v24.4s, v1.4s + Return1x4: + ret + + End: + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..3c648444a632cc805e831338bc52ded2ef6f9dbd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S @@ -0,0 +1,787 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: row +// w7: col +// w17: stride +// w13: c8_nhwc_c4 + +asm_function MatmulFloatNeon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + ldr x9, [sp, #152] + ldr x14, [sp, #160] + + mov w19, #32 // sizeof(float) * 8 + mul w15, w5, w19 // block stride of lhs/rhs: sizeof(float) * 8 * depth + mov x19, #4 + ldr x17, [sp, #144] + cbz x14, NoWinoSteps + mul x8, x7, x17 + mov x11, #8 + mul x11, x11, x17 + mul x8, x8, x19 + mul x11, x11, x19 +NoWinoSteps: + mul x17, x17, x19 + +L1: + mov w10, w6 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x19, x2 // reload dst ptr + +L2: + mov x16, x1 // reload rhs ptr + mov w13, w5 // reload depth + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + +LoopStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 + ld1 {v3.4s, v4.4s}, [x16], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs w13, w13, #1 + beq LoopEnd + +Loop: + ld1 {v0.4s}, [x12], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ld1 {v1.4s}, [x12], #16 + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, 
v2.s[3] + ld1 {v3.4s}, [x16], #16 + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ld1 {v4.4s}, [x16], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v2.4s}, [x12], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs w13, w13, #1 + bgt Loop + +LoopEnd: + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + +Bias: + cbz x3, Activation + ld1 {v0.4s}, [x3], #16 + ld1 {v1.4s}, [x3] + sub x3, x3, #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + +Activation: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + mov w13, #6 + dup v2.4s, w13 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + +Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + +Write: + cbnz x14, WriteWino + cbz x9, WriteC8 + cmp w7, #1 + beq Write1 + cmp w7, #2 + beq Write2 + cmp w7, #3 + beq Write3 + cmp w7, #4 + beq Write4 + cmp w7, #5 + beq Write5 + cmp w7, #6 + beq Write6 + cmp w7, #7 + beq Write7 + b Write8 + +Write1: 
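+    // Write1..Write8 store the 1..8 valid output columns of the 12-row tile;
+    // w10 counts the rows still to write and x17 is the dst row stride in bytes.
+    // Roughly (pseudocode sketch, names hypothetical):
+    //   for (r = 0; r < rows_left; ++r)
+    //     for (c = 0; c < cols_left; ++c) dst[r * stride + c] = acc[r][c];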
+ str s8, [x19] + cmp w10, #1 + beq WriteEnd + add x19, x19, x17 + str s10, [x19] + cmp w10, #2 + beq WriteEnd + add x19, x19, x17 + str s12, [x19] + cmp w10, #3 + beq WriteEnd + add x19, x19, x17 + str s14, [x19] + cmp w10, #4 + beq WriteEnd + add x19, x19, x17 + str s16, [x19] + cmp w10, #5 + beq WriteEnd + add x19, x19, x17 + str s18, [x19] + cmp w10, #6 + beq WriteEnd + add x19, x19, x17 + str s20, [x19] + cmp w10, #7 + beq WriteEnd + add x19, x19, x17 + str s22, [x19] + cmp w10, #8 + beq WriteEnd + add x19, x19, x17 + str s24, [x19] + cmp w10, #9 + beq WriteEnd + add x19, x19, x17 + str s26, [x19] + cmp w10, #10 + beq WriteEnd + add x19, x19, x17 + str s28, [x19] + cmp w10, #11 + beq WriteEnd + add x19, x19, x17 + str s30, [x19] + add x19, x19, x17 + b WriteEnd +Write2: + dup s9, v8.s[1] + stp s8, s9, [x19] + cmp w10, #1 + beq WriteEnd + add x19, x19, x17 + dup s11, v10.s[1] + stp s10, s11, [x19] + cmp w10, #2 + beq WriteEnd + add x19, x19, x17 + dup s13, v12.s[1] + stp s12, s13, [x19] + cmp w10, #3 + beq WriteEnd + add x19, x19, x17 + dup s15, v14.s[1] + stp s14, s15, [x19] + cmp w10, #4 + beq WriteEnd + add x19, x19, x17 + dup s17, v16.s[1] + stp s16, s17, [x19] + cmp w10, #5 + beq WriteEnd + add x19, x19, x17 + dup s19, v18.s[1] + stp s18, s19, [x19] + cmp w10, #6 + beq WriteEnd + add x19, x19, x17 + dup s21, v20.s[1] + stp s20, s21, [x19] + cmp w10, #7 + beq WriteEnd + add x19, x19, x17 + dup s23, v22.s[1] + stp s22, s23, [x19] + cmp w10, #8 + beq WriteEnd + add x19, x19, x17 + dup s25, v24.s[1] + stp s24, s25, [x19] + cmp w10, #9 + beq WriteEnd + add x19, x19, x17 + dup s27, v26.s[1] + stp s26, s27, [x19] + cmp w10, #10 + beq WriteEnd + add x19, x19, x17 + dup s29, v28.s[1] + stp s28, s29, [x19] + cmp w10, #11 + beq WriteEnd + add x19, x19, x17 + dup s31, v30.s[1] + stp s30, s31, [x19] + add x19, x19, x17 + b WriteEnd +Write3: + add x13, x19, #8 + dup s9, v8.s[1] + stp s8, s9, [x19] + add x19, x19, x17 + st1 {v8.s}[2], [x13], x17 + cmp w10, #1 + beq WriteEnd + dup s11, v10.s[1] + stp s10, s11, [x19] + add x19, x19, x17 + st1 {v10.s}[2], [x13], x17 + cmp w10, #2 + beq WriteEnd + dup s13, v12.s[1] + stp s12, s13, [x19] + add x19, x19, x17 + st1 {v12.s}[2], [x13], x17 + cmp w10, #3 + beq WriteEnd + dup s15, v14.s[1] + stp s14, s15, [x19] + add x19, x19, x17 + st1 {v14.s}[2], [x13], x17 + cmp w10, #4 + beq WriteEnd + dup s17, v16.s[1] + stp s16, s17, [x19] + add x19, x19, x17 + st1 {v16.s}[2], [x13], x17 + cmp w10, #5 + beq WriteEnd + dup s19, v18.s[1] + stp s18, s19, [x19] + add x19, x19, x17 + st1 {v18.s}[2], [x13], x17 + cmp w10, #6 + beq WriteEnd + dup s21, v20.s[1] + stp s20, s21, [x19] + add x19, x19, x17 + st1 {v20.s}[2], [x13], x17 + cmp w10, #7 + beq WriteEnd + dup s23, v22.s[1] + stp s22, s23, [x19] + add x19, x19, x17 + st1 {v22.s}[2], [x13], x17 + cmp w10, #8 + beq WriteEnd + dup s25, v24.s[1] + stp s24, s25, [x19] + add x19, x19, x17 + st1 {v24.s}[2], [x13], x17 + cmp w10, #9 + beq WriteEnd + dup s27, v26.s[1] + stp s26, s27, [x19] + add x19, x19, x17 + st1 {v26.s}[2], [x13], x17 + cmp w10, #10 + beq WriteEnd + dup s29, v28.s[1] + stp s28, s29, [x19] + add x19, x19, x17 + st1 {v28.s}[2], [x13], x17 + cmp w10, #11 + beq WriteEnd + dup s31, v30.s[1] + stp s30, s31, [x19] + add x19, x19, x17 + st1 {v30.s}[2], [x13] + b WriteEnd +Write4: + st1 {v8.4s}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s}, 
[x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s}, [x19], x17 + b WriteEnd +Write5: + add x13, x19, #16 + st1 {v8.4s}, [x19], x17 + str s9, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x19], x17 + str s11, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x19], x17 + str s13, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x19], x17 + str s15, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v16.4s}, [x19], x17 + str s17, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x19], x17 + str s19, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x19], x17 + str s21, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x19], x17 + str s23, [x13] + cmp w10, #8 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x19], x17 + str s25, [x13] + cmp w10, #9 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x19], x17 + str s27, [x13] + cmp w10, #10 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x19], x17 + str s29, [x13] + cmp w10, #11 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x19], x17 + str s31, [x13] + b WriteEnd +Write6: + add x13, x19, #16 + st1 {v8.4s}, [x19], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x19], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x19], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x19], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v16.4s}, [x19], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x19], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x19], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x19], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + cmp w10, #8 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x19], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + cmp w10, #9 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x19], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + cmp w10, #10 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x19], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + cmp w10, #11 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x19], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + b WriteEnd +Write7: + add x13, x19, #16 + add x16, x19, #24 + st1 {v8.4s}, [x19], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + add x13, x13, x17 + st1 {v9.s}[2], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x19], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + add x13, x13, x17 + st1 {v11.s}[2], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x19], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + add x13, x13, x17 + st1 {v13.s}[2], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x19], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + add x13, x13, x17 + st1 {v15.s}[2], [x16], x17 + cmp w10, #4 + beq 
WriteEnd + st1 {v16.4s}, [x19], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + add x13, x13, x17 + st1 {v17.s}[2], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s}, [x19], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + add x13, x13, x17 + st1 {v19.s}[2], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s}, [x19], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + add x13, x13, x17 + st1 {v21.s}[2], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s}, [x19], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + add x13, x13, x17 + st1 {v23.s}[2], [x16], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x19], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + add x13, x13, x17 + st1 {v25.s}[2], [x16], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x19], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + add x13, x13, x17 + st1 {v27.s}[2], [x16], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x19], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + add x13, x13, x17 + st1 {v29.s}[2], [x16], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s}, [x19], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + add x13, x13, x17 + st1 {v31.s}[2], [x16], x17 + b WriteEnd +WriteC8: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + b WriteEnd +WriteWino: + st1 {v8.4s, v9.4s}, [x19], x8 + st1 {v10.4s, v11.4s}, [x19], x8 + st1 {v12.4s, v13.4s}, [x19], x8 + st1 {v14.4s, v15.4s}, [x19], x8 + st1 {v16.4s, v17.4s}, [x19], x8 + st1 {v18.4s, v19.4s}, [x19], x8 + st1 {v20.4s, v21.4s}, [x19], x8 + st1 {v22.4s, v23.4s}, [x19], x8 + st1 {v24.4s, v25.4s}, [x19], x8 + st1 {v26.4s, v27.4s}, [x19], x8 + st1 {v28.4s, v29.4s}, [x19], x8 + st1 {v30.4s, v31.4s}, [x19], x8 + b WriteEnd +Write8: + st1 {v8.4s, v9.4s}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x19], x17 + +WriteEnd: + subs w10, w10, #12 // lhs row - 12 + bgt L2 + +End2: + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + cbz x3, NoBiasStep + add x3, x3, #32 // bias ptr + stride +NoBiasStep: + cbnz x14, WinoDstStep + cbz x9, NoDstStep + add x2, x2, #32 // dst ptr + stride + b NoDstStep +WinoDstStep: + add x2, x2, x11 +NoDstStep: + bgt L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..abaf79e31f18f16a87a2ddef16a5f09248e7c681 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S @@ 
-0,0 +1,1669 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64Opt + sub sp, sp, #160 + + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #48 // 12 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 192, otherwise 12 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #192 // block stride + + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne RowStart + mov x20, x2 +RowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + cmp x9, #3 + beq C4ReloadDst + mov x11, x2 + b NoReloadDst + C4ReloadDst: + mov x11, x20 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf + + LoopDepthStart: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] + + subs 
x19, x19, #1 + beq Bias + + LoopDepth: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + + LoopDepthStartHalf: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul 
v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + beq BiasHalf + + LoopDepthHalf: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf + + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s + + ActivationHalf: + cmp x4, #3 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + b Write + + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + cmp x9, #3 + beq C4ReloadDst8 + mov x11, x2 + b NoReloadDst8 + C4ReloadDst8: + mov x11, x20 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla 
v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #3 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, 
v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + cmp x9, #3 + beq C4ReloadDst4 + mov x11, x2 + b NoReloadDst4 + C4ReloadDst4: + mov x11, x20 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + 
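+      // narrow tail tile (<= 4 rows, <= 4 cols): only v8/v10/v12/v14 carry live
+      // accumulators here, so the bias add above and the activation clamp below
+      // touch just those four registers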
ActivationHalf4: + cmp x4, #3 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + st1 {v30.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq 
WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + 
st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + st1 {v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x20] + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + // add x20, x11, x8 + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + cmp x6, #8 + beq WriteEnd + str s24, [x11], #4 + cmp x6, #9 + beq WriteEnd + str s26, [x11], #4 + cmp x6, #10 + beq WriteEnd + str s28, [x11], #4 + cmp x6, #11 + beq WriteEnd + str s30, [x11], #4 + b WriteEnd + C4Write2: + // add x20, x11, x8 + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], #8 + b WriteEnd + C4Write3: + // add x20, x11, x8 + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 
{v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + add x11, x11, #12 + st1 {v22.s}[2], [x19] + add x19, x19, #12 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11] + add x11, x11, #12 + st1 {v24.s}[2], [x19] + add x19, x19, #12 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11] + add x11, x11, #12 + st1 {v26.s}[2], [x19] + add x19, x19, #12 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11] + add x11, x11, #12 + st1 {v28.s}[2], [x19] + add x19, x19, #12 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11] + add x11, x11, #12 + st1 {v30.s}[2], [x19] + add x19, x19, #12 + b WriteEnd + + C4Write4: + add x20, x11, x8 + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + str s25, [x19], #4 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + str s27, [x19], #4 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + str s29, [x19], #4 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + str s31, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 
{v25.2s}, [x19], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + st1 {v31.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], x15 + st1 {v25.s}[2], [x16], x15 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], x15 + st1 {v27.s}[2], [x16], x15 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], x15 + st1 {v29.s}[2], [x16], x15 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x16] + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + cmp x6, #8 + beq WriteEnd + + st1 {v24.4s}, [x11], #16 + st1 {v25.4s}, [x19], #16 + cmp x6, #9 + beq WriteEnd + + st1 {v26.4s}, [x11], #16 + st1 {v27.4s}, [x19], #16 + cmp x6, #10 + beq WriteEnd + + st1 {v28.4s}, [x11], #16 + st1 {v29.4s}, [x19], #16 + cmp x6, #11 + beq WriteEnd + + st1 {v30.4s}, [x11] + st1 {v31.4s}, [x19] + b WriteEnd + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S new file mode 100644 index 
0000000000000000000000000000000000000000..21dc91caf73ba999e4ef732a1d2551af1b134e8f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S @@ -0,0 +1,1229 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64OptRow12(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow12 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #48 // 12 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4, block stride 192, otherwise 12 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #192 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col * sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow + mov x20, x2 +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + cmp x9, #3 + beq C4ReloadDst + mov x11, x2 + b NoReloadDst + C4ReloadDst: + mov x11, x20 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s,
v2.s[3] + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + + LoopDepthStartHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] 
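+ // a-rows 8-11 (lanes of v2): this col <= 4 tail computes only the low 4 output columns held in v3; v4 is loaded just to step over the packed-8 rhs block (the Row8 variant calls this "only hold place")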
+ fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + beq BiasHalf + + LoopDepthHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf + + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s + + ActivationHalf: + cmp x4, #3 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + cmp x6, #9 + beq 
WriteEnd + st1 {v26.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + st1 {v30.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], 
x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + st1 {v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x20] + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + 
cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + cmp x6, #8 + beq WriteEnd + str s24, [x11], #4 + cmp x6, #9 + beq WriteEnd + str s26, [x11], #4 + cmp x6, #10 + beq WriteEnd + str s28, [x11], #4 + cmp x6, #11 + beq WriteEnd + str s30, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + add x11, x11, #12 + st1 {v22.s}[2], [x19] + add x19, x19, #12 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11] + add x11, x11, #12 + st1 {v24.s}[2], [x19] + add x19, x19, #12 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11] + add x11, x11, #12 + st1 {v26.s}[2], [x19] + add x19, x19, #12 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11] + add x11, x11, #12 + st1 {v28.s}[2], [x19] + add x19, x19, #12 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11] + add x11, x11, #12 + st1 {v30.s}[2], [x19] + add x19, x19, #12 + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + 
st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + str s25, [x19], #4 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + str s27, [x19], #4 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + str s29, [x19], #4 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + str s31, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + st1 {v31.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], x15 + st1 {v25.s}[2], [x16], x15 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], x15 + st1 {v27.s}[2], [x16], x15 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], x15 + st1 {v29.s}[2], [x16], x15 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x16] + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq 
WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.4s}, [x19], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.4s}, [x19], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.4s}, [x19], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.4s}, [x19] + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S new file mode 100644 index 0000000000000000000000000000000000000000..9798eabd7ec91ed27c8f462710442ede5689c1ef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S @@ -0,0 +1,597 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow4 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #16 // 4 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4, block stride 64, otherwise 4 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #64 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col * sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow4 + mov x20, x2 +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + cmp x9, #3 + beq C4ReloadDst4 + mov x11, x2 + b NoReloadDst4 + C4ReloadDst4: + mov x11, x20 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s +
fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + ActivationHalf4: + cmp x4, #3 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq 
WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + st1 {v14.s}[2], [x19] + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 
{v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol4 + + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow4 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S new file mode 100644 index 0000000000000000000000000000000000000000..998b1e93081b31c9162d6f6873a5e787e2d0165b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S @@ -0,0 +1,911 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64OptRow8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow8 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #32 // 8 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4, block stride 128, otherwise 8 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #128 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col * sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow8 + mov x20, x2 +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + cmp x9, #3 + beq C4ReloadDst8 + mov x11, x2 + b NoReloadDst8 + C4ReloadDst8: + mov x11, x20 +
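// x11 now holds the C4 dst base saved in x20 at LoopRowStart +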
NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 // weight packed 8, only hold place + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 // only hold place + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, 
v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #3 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq 
WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 
{v22.4s, v23.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + st1 {v22.s}[2], [x19] + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq 
WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol8 + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopCol8 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S new file mode 100644 index 0000000000000000000000000000000000000000..10a74163ded860f237d46b8ed4d83eeb37a88e81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S @@ -0,0 +1,420 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int row, int col, int stride, int filter_peroc) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// w14: row +// w15: col +// w24: stride +// w27: filter_peroc + +asm_function MatmulInt8Neon64 + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr w8, [sp, #208] + ldr w9, [sp, #216] + ldr w10, [sp, #224] + ldr x11, [sp, #232] + ldr x12, [sp, #240] + ldr x13, [sp, #248] + ldr w14, [sp, #256] + ldr w15, [sp, #264] + ldr w24, [sp, #272] + ldr w27, [sp, #280] + + mov w17, #4 // sizeof(int8)*4 + mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + mov w17, #1 + mov x25, x2 +L1: + cmp w4, #0 // if at the end of col4 + beq End1 + + mov w16, w3 // reset a row4 counter + mov w23, w14 // reset a row counter + mov x17, x0 // reload a ptr + mov x22, x6 // reload a_sums ptr +L2: + cmp w16, #0 + beq End2 + + mov x28, x1 // reload b ptr + mov x19, x7 // reload bias ptr + mov w20, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w20, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x28], #16 + ld1 {v5.16b}, [x28], #16 + ld1 {v6.16b}, [x28], #16 + ld1 {v7.16b}, [x28], #16 + + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h 
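+
+  // Note (illustrative sketch, not part of the build): each smull/smlal2 +
+  // sadalp group above accumulates, per output lane, an int8 dot product
+  // over 16 depth elements into int32. In scalar C, one 4x4 tile step is:
+  //   for (int r = 0; r < 4; ++r)
+  //     for (int c = 0; c < 4; ++c)
+  //       for (int d = 0; d < 16; ++d)
+  //         acc[r][c] += (int32_t)a[r * 16 + d] * (int32_t)b[c * 16 + d];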
+ subs w20, w20, #16 // depth - 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x19], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + cmp w27, #0 + beq PerTLoad +PerCLoad: + ld1 {v20.4s}, [x6], #16 + ld1 {v21.4s}, [x6], #16 + ld1 {v22.4s}, [x6], #16 + ld1 {v23.4s}, [x6], #16 + + ld1 {v13.4s}, [x12] + ld1 {v12.4s}, [x11] + ld1 {v11.4s}, [x13] + b Apply + +PerTLoad: + ld1 {v14.4s}, [x22], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + + ld1 {v14.s}[0], [x12] + dup v13.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v12.4s, v14.s[0] + ld1 {v14.s}[0], [x13] + dup v11.4s, v14.s[0] + b Apply + +Apply: + // Subtract (Asums*Zb) + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + // Apply left shift + sqshl v16.4s, v16.4s, v13.4s + sqshl v17.4s, v17.4s, v13.4s + sqshl v18.4s, v18.4s, v13.4s + sqshl v19.4s, v19.4s, v13.4s + + // Apply the fixed-point part of the multiplier. + sqrdmulh v16.4s, v16.4s, v12.4s + sqrdmulh v17.4s, v17.4s, v12.4s + sqrdmulh v18.4s, v18.4s, v12.4s + sqrdmulh v19.4s, v19.4s, v12.4s + + // Apply right shift + and v20.16b, v11.16b, v16.16b + sshr v20.4s, v20.4s, #31 + sqadd v16.4s, v16.4s, v20.4s + srshl v16.4s, v16.4s, v11.4s + and v21.16b, v11.16b, v17.16b + sshr v21.4s, v21.4s, #31 + sqadd v17.4s, v17.4s, v21.4s + srshl v17.4s, v17.4s, v11.4s + and v22.16b, v11.16b, v18.16b + sshr v22.4s, v22.4s, #31 + sqadd v18.4s, v18.4s, v22.4s + srshl v18.4s, v18.4s, v11.4s + and v23.16b, v11.16b, v19.16b + sshr v23.4s, v23.4s, #31 + sqadd v19.4s, v19.4s, v23.4s + srshl v19.4s, v19.4s, v11.4s + + // Add the destination zero point + dup v10.4s, w10 + add v16.4s, v16.4s, v10.4s + add v17.4s, v17.4s, v10.4s + add v18.4s, v18.4s, v10.4s + add v19.4s, v19.4s, v10.4s + + // Apply the act_min bound + dup v9.4s, w8 + smax v16.4s, v16.4s, v9.4s + smax v17.4s, v17.4s, v9.4s + smax v18.4s, v18.4s, v9.4s + smax v19.4s, v19.4s, v9.4s + + // Apply the act_max bound + dup v8.4s, w9 + smin v16.4s, v16.4s, v8.4s + smin v17.4s, v17.4s, v8.4s + smin v18.4s, v18.4s, v8.4s + smin v19.4s, v19.4s, v8.4s + + // int32 -> int16 + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v17.4s + sqxtn v14.4h, v18.4s + sqxtn2 v14.8h, v19.4s + + // int16 -> int8 + sqxtn v15.8b, v13.8h + sqxtn2 v15.16b, v14.8h + + cmp w23, #4 + blt Write // if rows < 4 + cmp w15, #4 + blt Write // if cols < 4 + + st1 {v15.s}[0], [x2], x24 + st1 {v15.s}[1], [x2], x24 + st1 {v15.s}[2], [x2], x24 + st1 {v15.s}[3], [x2], x24 + b Endwrite + +Write: + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol4: + st1 {v15.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.s}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.s}[2], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.s}[3], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + st1 {v15.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + 
st1 {v15.b}[5], [x26], #1 + st1 {v15.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + st1 {v15.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + st1 {v15.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol1: + st1 {v15.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.b}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.b}[8], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.b}[12], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #4 // a row4 counter - 4 + sub w23, w23, #4 // a row counter - 4 + b L2 + +End2: + sub w4, w4, #4 // b col4 counter - 4 + sub w15, w15, #4 // b col counter - 4 + add x1, x1, x21 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + add x25, x25, #4 // output + stride(4 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #16 + add x11, x11, #16 + add x13, x13, #16 +PerTEnd2: + b L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..5fbb982599bfe26e7c2397fa1cd30dbe1d6e8068 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S @@ -0,0 +1,356 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row4 +// x4: col4 +// x5: deep16 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +asm_function MatmulInt8Opt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + stp x29, x30, [sp, #208] + + ldr w8, [sp, #224] + ldr w9, [sp, #232] + ldr w10, [sp, #240] + ldr x11, [sp, #248] + ldr x12, [sp, #256] + ldr x13, [sp, #264] + ldr x14, [sp, #272] + ldr x15, [sp, #280] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x29, x7 // reload bias ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #288] // reload filter_zp + + LoopCol: + mov x25, x6 // reload a_sums ptr + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + LoopDepth: + ld1 {v0.16b, v1.16b}, [x20], #32 + ld1 {v4.16b, v5.16b}, [x16], #32 + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + ld1 {v6.16b, v7.16b}, [x16], #32 + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + ld1 {v2.16b, v3.16b}, [x20], #32 + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs x21, x21, #16 // depth - 16 + bgt LoopDepth + + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp 
v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + Bias: + cbz x7, NoBias + ld1 {v15.4s}, [x29], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + NoBias: + ld1r {v20.4s}, [x25], #4 + ld1r {v21.4s}, [x25], #4 + ld1r {v22.4s}, [x25], #4 + ld1r {v23.4s}, [x25], #4 + cbz x15, ApplySum + + ld1 {v14.4s}, [x28], #16 + mul v20.4s, v20.4s, v14.4s + mul v21.4s, v21.4s, v14.4s + mul v22.4s, v22.4s, v14.4s + mul v23.4s, v23.4s, v14.4s + + ApplySum: + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + cbnz x15, PerCLoad + + ld1r {v13.4s}, [x12] + ld1r {v12.4s}, [x11] + ld1r {v11.4s}, [x13] + b Quantize + + PerCLoad: + ld1 {v13.4s}, [x12], #16 + ld1 {v12.4s}, [x11], #16 + ld1 {v11.4s}, [x13], #16 + + Quantize: + sqshl v16.4s, v16.4s, v13.4s + sqshl v17.4s, v17.4s, v13.4s + sqshl v18.4s, v18.4s, v13.4s + sqshl v19.4s, v19.4s, v13.4s + + sqrdmulh v16.4s, v16.4s, v12.4s + sqrdmulh v17.4s, v17.4s, v12.4s + sqrdmulh v18.4s, v18.4s, v12.4s + sqrdmulh v19.4s, v19.4s, v12.4s + + and v20.16b, v11.16b, v16.16b + sshr v20.4s, v20.4s, #31 + sqadd v16.4s, v16.4s, v20.4s + srshl v16.4s, v16.4s, v11.4s + and v21.16b, v11.16b, v17.16b + sshr v21.4s, v21.4s, #31 + sqadd v17.4s, v17.4s, v21.4s + srshl v17.4s, v17.4s, v11.4s + and v22.16b, v11.16b, v18.16b + sshr v22.4s, v22.4s, #31 + sqadd v18.4s, v18.4s, v22.4s + srshl v18.4s, v18.4s, v11.4s + and v23.16b, v11.16b, v19.16b + sshr v23.4s, v23.4s, #31 + sqadd v19.4s, v19.4s, v23.4s + srshl v19.4s, v19.4s, v11.4s + + dup v10.4s, w10 + add v16.4s, v16.4s, v10.4s + add v17.4s, v17.4s, v10.4s + add v18.4s, v18.4s, v10.4s + add v19.4s, v19.4s, v10.4s + + dup v9.4s, w8 + smax v16.4s, v16.4s, v9.4s + smax v17.4s, v17.4s, v9.4s + smax v18.4s, v18.4s, v9.4s + smax v19.4s, v19.4s, v9.4s + + dup v8.4s, w9 + smin v16.4s, v16.4s, v8.4s + smin v17.4s, v17.4s, v8.4s + smin v18.4s, v18.4s, v8.4s + smin v19.4s, v19.4s, v8.4s + + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v17.4s + sqxtn v14.4h, v18.4s + sqxtn2 v14.8h, v19.4s + + sqxtn v15.8b, v13.8h + sqxtn2 v15.16b, v14.8h + + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + b Write4 + + Write1: + add x27, x27, #1 + st1 {v15.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.b}[4], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.b}[8], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.b}[12], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v15.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.h}[2], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.h}[4], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.h}[6], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v15.h}[0], [x19], x14 + st1 {v15.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.h}[2], [x19], x14 + st1 {v15.b}[6], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.h}[4], [x19], x14 + st1 {v15.b}[10], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.h}[6], [x19], x14 + st1 {v15.b}[14], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v15.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.s}[1], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.s}[2], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.s}[3], [x19], x14 + 
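+  // Note (illustrative sketch, not part of the build): Write1..Write4 above
+  // store the low 1..4 int8 lanes of each of up to four result rows, with
+  // row stride x14; roughly, in C, with res holding the 4x4 int8 tile:
+  //   for (int r = 0; r < row_remain; ++r)      // row_remain = min(row, 4)
+  //     memcpy(dst + r * stride, res + r * 4, col_remain);  // col_remain <= 4
+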
+ WriteEnd: + subs x17, x17, #4 + bgt LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #248] + ldr x12, [sp, #256] + ldr x13, [sp, #264] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S new file mode 100644 index 0000000000000000000000000000000000000000..01f29170d2dab2bbac3acf911e9037fb566c8498 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S @@ -0,0 +1,193 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, +// const int *input_sum, const int *bias) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias + +asm_function MatMulR4Int8Neon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + mov w15, #0 // b col index + mov w16, #0 // a row index + mov w17, #4 // sizeof(int8)*4 + mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + +L1: + cmp w15, w4 + beq End1 + + mov w16, #0 // reset a row index + mov x17, x0 // reload a ptr + mov x13, x6 // reload a_sums ptr +L2: + cmp w16, w3 + beq End2 + + mov x19, x1 // reload b ptr + mov x10, x7 // reload bias ptr + mov w11, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w11, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x19], #16 + ld1 {v5.16b}, [x19], #16 + ld1 {v6.16b}, [x19], #16 + ld1 {v7.16b}, [x19], #16 + + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp 
v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs w11, w11, #16 // depth - 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x10], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + // Subtract (Asums*Zb) + ld1 {v14.4s}, [x13], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + add w16, w16, #4 // a row index + 4 + b L2 + +End2: + add w15, w15, #4 // b col index + 4 + add x1, x1, x12 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + b L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..ef5b39a4b4926c83dd328d7fb33ea88740d406d9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S @@ -0,0 +1,183 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinograd(float *matrix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel) + // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel, x7: c4_channel +asm_function MatrixMultiplyWinograd + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such an amount of parameters + sub sp, sp, #48 + st1 {v8.4s}, [sp] + stp x19, x20, [sp, #16] + stp x21, x22, [sp, #32] + mov x8, #4 + mul x10, x5, x8 + mov x17, x3 // m + mul x13, x6, x8 // in_channel * 4 + mul x21, x13, x4 // in_channel * k + + LoopM: + mov x15, x5 // n + mov x14, x1 // mat_b + LoopN: + mov x16, x0 // mat_a_m + sub x22, x5, x15 // ni + sub x19, x17, x3 // mi + mul x22, x22, x17 // ni * m + mov x11, x6 // in_channel + add x22, x22, x19 // (ni * m) + mi + mul x22, x22, x7 // x22 * c4_channel + add x20, x2, x22 // dst + offset + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + b EndLoopC + LoopC16: + mov x12, x14 + mov x9, x4 // new_k + dup v5.4s, wzr + dup v6.4s, wzr + dup v7.4s, wzr + dup v8.4s, wzr + LoopK16: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + fmla v6.4s, v1.4s, v4.s[0] + fmla v7.4s, v2.4s, v4.s[0] + fmla v8.4s, v3.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK16 + Write16: + st1 {v5.4s}, [x20], #16 + st1 {v6.4s}, [x20], #16 + st1 {v7.4s}, [x20], #16 + st1 {v8.4s}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #64 // add 64B + subs x11, x11, #16 + beq EndLoopC + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC8: + dup v5.4s, wzr + dup v6.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK8: + ld1 {v0.4s, v1.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + fmla v6.4s, v1.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK8 + Write8: + st1 {v5.4s}, [x20], #16 + st1 {v6.4s}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #32 // add 32B + subs x11, x11, #8 + beq EndLoopC + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC4: + dup v5.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK4: + ld1 {v0.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK4 + Write4: + st1 {v5.4s}, [x20], #16 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #16 // add 16B + subs x11, x11, #4 + beq EndLoopC + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC: + dup v5.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK: + ldr s0, [x16] + add x16, x16, x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK + Write1: + str s5, [x20], #4 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #4 // ptr add 4B + subs x11, x11, #1 + beq EndLoopC + b LoopC + + EndLoopC: + add x14, x14, #4 + subs x15, x15, #1 + beq EndLoopN + b LoopN + EndLoopN: + subs x3, x3, #1 + beq EndLoopM + add x0, x0, x21 + b LoopM + EndLoopM: + ld1 {v8.4s}, [sp], #16 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S new file mode 100644 index 0000000000000000000000000000000000000000..e88eab25567a475291e2fa6c5c41a98af580536f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S @@ -0,0 +1,316 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 src x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x12 oc_stride +// x14 x15 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc4 loop control +// w13 hw loop control + + +asm_function WinogradPostFuncBiasReluC4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + mov x10, #4 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #4 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v16.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v16.4s + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_8x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_8x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + st1 {v4.4s}, [x15], x12 + st1 {v5.4s}, [x15], x12 + st1 {v6.4s}, [x15], x12 + st1 {v7.4s}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s +Relu_4x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s 
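+
+  // Note (illustrative sketch, not part of the build): each C4 group
+  // computes dst = act(src + bias); relu_type 1 clamps to [0, inf) and
+  // relu_type 3 to [0, 6], the relu6 path falling through into the fmax
+  // sequence above. Per element, in C:
+  //   float v = src[i] + bias[i & 3];
+  //   if (relu_type == 3) v = fminf(v, 6.0f);
+  //   if (relu_type == 1 || relu_type == 3) v = fmaxf(v, 0.0f);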
+Write_4x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #8 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S new file mode 100644 index 
0000000000000000000000000000000000000000..99213447720c96758b74dae5e9bbf15bd2d64eb8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S @@ -0,0 +1,553 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, +// size_t plane_size, size_t stride, int relu_type); +// x0 dst x1 src x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride x7 relu_type + +// v0 ~ v15 value +// v16 v17 bias data +// x14 x15 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc8 loop control +// w13 hw loop control + +asm_function PostFuncBiasReluC8 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x15, #4 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #8 + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + +Loop_8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + fadd v8.4s, v8.4s, v16.4s + fadd v9.4s, v9.4s, v17.4s + fadd v10.4s, v10.4s, v16.4s + fadd v11.4s, v11.4s, v17.4s + fadd v12.4s, v12.4s, v16.4s + fadd v13.4s, v13.4s, v17.4s + fadd v14.4s, v14.4s, v16.4s + fadd v15.4s, v15.4s, v17.4s + + cmp x7, #3 + beq Relu6_8x8 + cmp x7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s +Relu_8x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s +Write_8x8: + st1 {v0.4s, v1.4s}, [x15], x6 + st1 {v2.4s, v3.4s}, [x15], x6 + st1 {v4.4s, v5.4s}, [x15], x6 + st1 {v6.4s, 
v7.4s}, [x15], x6 + st1 {v8.4s, v9.4s}, [x15], x6 + st1 {v10.4s, v11.4s}, [x15], x6 + st1 {v12.4s, v13.4s}, [x15], x6 + st1 {v14.4s, v15.4s}, [x15], x6 + b Loop_8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + + cmp x7, #3 + beq Relu6_4x8 + cmp x7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_4x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_4x8: + st1 {v0.4s, v1.4s}, [x15], x6 + st1 {v2.4s, v3.4s}, [x15], x6 + st1 {v4.4s, v5.4s}, [x15], x6 + st1 {v6.4s, v7.4s}, [x15], x6 + +Loop_1x8: + cmp x7, #3 + beq Relu6_1x8 + cmp x7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + 
fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x15, x0, #8 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x15], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x15], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x15], x6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp x7, #3 + beq Loop_C1_4_Relu6 + cmp x7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Relu +Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Write + +Loop_C1_5: + add x15, x0, #16 + cmp x7, #3 + beq Loop_C1_5_Relu6 + cmp x7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + str s1, [x15] + add x15, x15, x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + str s1, [x15] + add x15, x15, x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + str s1, [x15] + add x15, x15, x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x15, x0, #16 + cmp x7, #3 + beq Loop_C1_6_Relu6 + cmp x7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], 
x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x15, x0, #16 + add x14, x0, #24 + cmp x7, #3 + beq Loop_C1_7_Relu6 + cmp x7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Write + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S new file mode 100644 index 0000000000000000000000000000000000000000..71c44685d249cc54469c60a7f598d9175a5db56f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S @@ -0,0 +1,259 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, +// size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, +// int32_t zp, int32_t mini, int32_t maxi); +// x0 in +// x1 bias +// x2 out +// x3 oc4div +// x4 oc4res +// x5 plane +// x6 stride +// x7 multiplier +// x8 left_shift +// x9 right_shift +// x10 zp +// x11 mini +// x12 maxi + +// v0 ~ v15 value +// x24 x25 write loop tmp buf + + +// v16 bias data + +// v26 multiplier +// v27 left_shift +// v28 right_shift +// v29 zp +// v30 min +// v31 max + +// w15 oc4 loop control +// w16 hw loop control + +asm_function PostFuncInt8C4Neon64 + sub sp, sp, #16 + stp x24, x25, [sp] + + ldr w8, [sp, #16] + ldr w9, [sp, #24] + ldr w10, [sp, #32] + ldr w11, [sp, #40] + ldr w12, [sp, #48] + ldr w13, [sp, #56] + + dup v26.4s, w7 + dup v27.4s, w8 + dup v28.4s, w9 + dup v29.4s, w10 + dup v30.4s, w11 + dup v31.4s, w12 + + mov x15, #0 + +Loop_C4: + cmp x15, x3 + beq Loop_C1 + mov x25, #4 + mul x24, x15, x25 + add x25, x2, x24 + add w15, w15, #4 + mov w16, w5 + ld1 {v16.4s}, [x1], #16 + +Loop_4x4: + cmp x16, #4 + blt Loop_1x4 + sub x16, x16, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + + add v0.4s, v0.4s, v16.4s + add v1.4s, v1.4s, v16.4s + add v2.4s, v2.4s, v16.4s + add v3.4s, v3.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqshl v1.4s, v1.4s, v27.4s + sqshl v2.4s, v2.4s, v27.4s + sqshl v3.4s, v3.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + sqrdmulh v1.4s, v1.4s, v26.4s + sqrdmulh v2.4s, v2.4s, v26.4s + sqrdmulh v3.4s, v3.4s, v26.4s + and v4.16b, v28.16b, v0.16b + and v5.16b, v28.16b, v1.16b + and v6.16b, v28.16b, v2.16b + and v7.16b, v28.16b, v3.16b + sshr v4.4s, v4.4s, #31 + sshr v5.4s, v5.4s, #31 + sshr v6.4s, v6.4s, #31 + sshr v7.4s, v7.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + sqadd v1.4s, v1.4s, v5.4s + sqadd v2.4s, v2.4s, v6.4s + sqadd v3.4s, v3.4s, v7.4s + srshl v0.4s, v0.4s, v28.4s + srshl v1.4s, v1.4s, v28.4s + srshl v2.4s, v2.4s, v28.4s + srshl v3.4s, v3.4s, v28.4s + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + add v2.4s, v2.4s, v29.4s + add v3.4s, v3.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smax v2.4s, v2.4s, v30.4s + smax v3.4s, v3.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v31.4s + smin v3.4s, v3.4s, v31.4s + sqxtn v4.4h, v0.4s + sqxtn v5.4h, v1.4s + sqxtn v6.4h, v2.4s + sqxtn v7.4h, v3.4s + sqxtn v0.8b, v4.8h + sqxtn v1.8b, v5.8h + sqxtn v2.8b, v6.8h + sqxtn v3.8b, v7.8h + + st1 {v0.s}[0], [x25], x6 + st1 {v1.s}[0], [x25], x6 + st1 {v2.s}[0], [x25], x6 + st1 {v3.s}[0], [x25], x6 + b Loop_4x4 + + +Loop_1x4: + cmp x16, #0 + beq Loop_C4 + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.s}[0], [x25], x6 + b Loop_1x4 + +Loop_C1: + cmp x4, #0 + beq End + mov x16, x5 + ld1 {v16.4s}, [x1], #16 + mov x25, #4 + mul x24, x15, x25 + add x25, x2, x24 + add x24, x25, #2 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + 
sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.b}[0], [x25], x6 + b Loop_C1_1 + + +Loop_C1_2: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.h}[0], [x25], x6 + b Loop_C1_2 + + +Loop_C1_3: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.h}[0], [x25], x6 + st1 {v0.b}[2], [x24], x6 + b Loop_C1_3 + + +End: + ldp x24, x25, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S new file mode 100644 index 0000000000000000000000000000000000000000..53b6ec5a053c9f2b6445f7acf54729afa3fecbb6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div4, +// size_t oc_res4, size_t stride); + +// x0 src +// x1 sum +// x2 zp +// w3 hw4 +// w4 ic16 +// w5 oc_div4 +// w6 oc_res4 +// w7 stride + +asm_function PreSum4x16Int8Peroc + mov w8, #0 + +RowLoop: + cmp w8, w3 + beq End + add w8, w8, #4 + dup v16.4s, wzr + mov w9, #0 + mov x16, x2 + +Sum: + cmp w9, w4 + beq Mul + add w9, w9, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b Sum + +Mul: + mov x12, x1 + add x1, x1, #64 + mov w9, #0 + + dup v1.4s, v16.s[0] + dup v2.4s, v16.s[1] + dup v3.4s, v16.s[2] + dup v4.4s, v16.s[3] + +WriteOc4: + cmp w9, w5 + beq OcRes4 + add w9, w9, #4 + ld1 {v5.4s}, [x16], #16 + + mul v16.4s, v5.4s, v1.4s + mul v17.4s, v5.4s, v2.4s + mul v18.4s, v5.4s, v3.4s + mul v19.4s, v5.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + add x12, x12, x7 + b WriteOc4 + +OcRes4: + cmp w6, #0 + beq RowLoop + dup v15.4s, wzr + cmp w6, #1 + beq OcRes4_1 + cmp w6, #2 + beq OcRes4_2 + cmp w6, #3 + beq OcRes4_3 + +OcRes4_1: + ld1 {v15.s}[0], [x16] + b OcRes4End + +OcRes4_2: + ld1 {v15.d}[0], [x16] + b OcRes4End + +OcRes4_3: + ld1 {v15.d}[0], [x16] + add x16, x16, #8 + ld1 {v15.s}[2], [x16] + b OcRes4End + +OcRes4End: + mul v16.4s, v15.4s, v1.4s + mul v17.4s, v15.4s, v2.4s + mul v18.4s, v15.4s, v3.4s + mul v19.4s, v15.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + b RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S new file mode 100644 index 0000000000000000000000000000000000000000..1e7f0709e1885d3e99a01c53c13d5eed28e9e3c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void PreSum4x16Int8Pert(const int8_t *src, int32_t *dst, size_t row4, size_t col16, int32_t filter_zp); + +// x0 src +// x1 dst +// w2 row4 +// w3 col16 +// w4 filter_zp + +asm_function PreSum4x16Int8Pert + dup v17.4s, w4 + mov w5, #0 + +RowLoop: + cmp w5, w2 + beq End + add w5, w5, #4 + dup v16.4s, wzr + mov w6, #0 + +CalLoop: + cmp w6, w3 + beq Write + add w6, w6, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b CalLoop + +Write: + mul v16.4s, v16.4s, v17.4s + st1 {v16.4s}, [x1], #16 + b RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S new file mode 100644 index 0000000000000000000000000000000000000000..e2317a700cf4baf3829bb12f9dd2aec96ae60808 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S @@ -0,0 +1,294 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SPMM8x8Fp32(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, +// const float *bias, ActType act_type, size_t out_stride); +// x0: a +// x1: b +// x2: nnz +// x3: dmap +// x4: c +// x5: bias +// w6: act_type +// x7: out_stride + +// wdata tmp w8 +// loop_oc_count w9 +// loop_nnz_count w10 +// dmap tmp w11 +// a_ptr +// 8 x 1 fp32 A v0-v1 +// fp32 B-value v2 +// uint32 B-NNZ x9 +// uint32 B-INDEX x10 +// 4 MIN v3 +// 4 MAX v4 +// 2 vacc v5-v6 +// 8 x 8 fp32 C v16-v31 + +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] v20[1] v22[1] v24[1] v26[1] v28[1] v30[1] +// v16[2] v18[2] v20[2] v22[2] v24[2] v26[2] v28[2] v30[2] +// v16[3] v18[3] v20[3] v22[3] v24[3] v26[3] v28[3] v30[3] +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] v21[1] v23[1] v25[1] v27[1] v29[1] v31[1] +// v17[2] v19[2] v21[2] v23[2] v25[2] v27[2] v29[2] v31[2] +// v17[3] v19[3] v21[3] v23[3] v25[3] v27[3] v29[3] v31[3] + +asm_function SPMM8x8Fp32 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + // init output with bias + ldr w8, [x5], #4 + dup v16.4s, w8 + dup v17.4s, w8 + ldr w8, [x5], #4 + dup v18.4s, w8 + dup v19.4s, w8 + ldr w8, [x5], #4 + dup v20.4s, w8 + dup v21.4s, w8 + ldr w8, [x5], #4 + dup v22.4s, w8 + dup v23.4s, w8 + ldr w8, [x5], #4 + dup v24.4s, w8 + dup v25.4s, w8 + ldr w8, [x5], #4 + dup v26.4s, w8 + dup v27.4s, w8 + ldr w8, [x5], #4 + dup v28.4s, w8 + dup v29.4s, w8 + ldr w8, [x5] + dup v30.4s, w8 + dup v31.4s, w8 + + // OC 0 + ldr w10, [x2], #4 // load nnz + cmp w10, #0 + beq OC_1 +LOOP_NNZ0: + ldr x11, [x3], #8 // load dmap + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] // load inputs + ldr w8, [x1], #4 // load weight + dup v2.4s, w8 + // matmul + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ0 + +OC_1: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_2 +LOOP_NNZ1: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v18.4s, v0.4s, v2.4s + fmla v19.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ1 + +OC_2: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_3 +LOOP_NNZ2: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v20.4s, v0.4s, v2.4s + fmla v21.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ2 + +OC_3: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_4 +LOOP_NNZ3: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v22.4s, v0.4s, v2.4s + fmla v23.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ3 + +OC_4: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_5 +LOOP_NNZ4: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v24.4s, v0.4s, v2.4s + fmla v25.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ4 + +OC_5: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_6 +LOOP_NNZ5: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v26.4s, v0.4s, v2.4s + fmla v27.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ5 + +OC_6: + ldr w10, 
[x2], #4 + cmp w10, #0 + beq OC_7 +LOOP_NNZ6: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v28.4s, v0.4s, v2.4s + fmla v29.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ6 + +OC_7: + ldr w10, [x2], #4 + cmp w10, #0 + beq REORDER_OUT +LOOP_NNZ7: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v30.4s, v0.4s, v2.4s + fmla v31.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ7 + + // reorder output +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] v20[1] v22[1] v24[1] v26[1] v28[1] v30[1] +// v16[2] v18[2] v20[2] v22[2] v24[2] v26[2] v28[2] v30[2] +// v16[3] v18[3] v20[3] v22[3] v24[3] v26[3] v28[3] v30[3] +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] v21[1] v23[1] v25[1] v27[1] v29[1] v31[1] +// v17[2] v19[2] v21[2] v23[2] v25[2] v27[2] v29[2] v31[2] +// v17[3] v19[3] v21[3] v23[3] v25[3] v27[3] v29[3] v31[3] + +// v0[0] v0[1] v0[2] v0[3] v1[0] v1[1] v1[2] v1[3] +// v2[0] v2[1] v2[2] v2[3] v3[0] v3[1] v3[2] v3[3] +// v4[0] v4[1] v4[2] v4[3] v5[0] v5[1] v5[2] v5[3] +// v6[0] v6[1] v6[2] v6[3] v7[0] v7[1] v7[2] v7[3] +// v8[0] v8[1] v8[2] v8[3] v9[0] v9[1] v9[2] v9[3] +// v10[0] v10[1] v10[2] v10[3] v11[0] v11[1] v11[2] v11[3] +// v12[0] v12[1] v12[2] v12[3] v13[0] v13[1] v13[2] v13[3] +// v14[0] v14[1] v14[2] v14[3] v15[0] v15[1] v15[2] v15[3] + +REORDER_OUT: + zip1 v1.4s, v16.4s, v18.4s + zip2 v3.4s, v16.4s, v18.4s + zip1 v9.4s, v17.4s, v19.4s + zip2 v11.4s, v17.4s, v19.4s + zip1 v5.4s, v20.4s, v22.4s + zip2 v7.4s, v20.4s, v22.4s + zip1 v13.4s, v21.4s, v23.4s + zip2 v15.4s, v21.4s, v23.4s + trn1 v0.2d, v1.2d, v5.2d + trn2 v2.2d, v1.2d, v5.2d + trn1 v4.2d, v3.2d, v7.2d + trn2 v6.2d, v3.2d, v7.2d + trn1 v8.2d, v9.2d, v13.2d + trn2 v10.2d, v9.2d, v13.2d + trn1 v12.2d, v11.2d, v15.2d + trn2 v14.2d, v11.2d, v15.2d + + zip1 v16.4s, v24.4s, v26.4s + zip2 v17.4s, v24.4s, v26.4s + zip1 v20.4s, v25.4s, v27.4s + zip2 v21.4s, v25.4s, v27.4s + zip1 v18.4s, v28.4s, v30.4s + zip2 v19.4s, v28.4s, v30.4s + zip1 v22.4s, v29.4s, v31.4s + zip2 v23.4s, v29.4s, v31.4s + trn1 v1.2d, v16.2d, v18.2d + trn2 v3.2d, v16.2d, v18.2d + trn1 v5.2d, v17.2d, v19.2d + trn2 v7.2d, v17.2d, v19.2d + trn1 v9.2d, v20.2d, v22.2d + trn2 v11.2d, v20.2d, v22.2d + trn1 v13.2d, v21.2d, v23.2d + trn2 v15.2d, v21.2d, v23.2d + +WRITE_OUT: + st1 {v0.4s, v1.4s}, [x4], x7 + st1 {v2.4s, v3.4s}, [x4], x7 + st1 {v4.4s, v5.4s}, [x4], x7 + st1 {v6.4s, v7.4s}, [x4], x7 + st1 {v8.4s, v9.4s}, [x4], x7 + st1 {v10.4s, v11.4s}, [x4], x7 + st1 {v12.4s, v13.4s}, [x4], x7 + st1 {v14.4s, v15.4s}, [x4] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..dfb70710ac0f5594860b02bc8307ab172c92f1eb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S @@ -0,0 +1,279 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp32 +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t ic4, size_t cal_num, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] +add x9, sp, #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + +mov x7, #4 //sizeof(float) +mul x3, x3, x7 +mov x7, #64 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmul v18.4s, v8.4s, v2.s[0] + fmul v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmul v20.4s, v8.4s, v4.s[0] + fmul v21.4s, v8.4s, v5.s[0] + fmul v22.4s, v8.4s, v6.s[0] + fmul v23.4s, v8.4s, v7.s[0] + fmul v24.4s, v12.4s, v0.s[0] + fmul v25.4s, v12.4s, v1.s[0] + fmul v26.4s, v12.4s, v2.s[0] + fmul v27.4s, v12.4s, v3.s[0] + fmul v28.4s, v12.4s, v4.s[0] + fmul v29.4s, v12.4s, v5.s[0] + fmul v30.4s, v12.4s, v6.s[0] + fmul v31.4s, v12.4s, v7.s[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #128 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #128 + prfm pldl1keep, [x8, #128] + prfm pldl1keep, [x8, #192] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + fmla v29.4s, v15.4s, v5.s[3] + fmla v30.4s, v15.4s, v6.s[3] + fmla v31.4s, v15.4s, v7.s[3] + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, 
[x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + fmla v24.4s, v12.4s, v0.s[0] + fmla v25.4s, v12.4s, v1.s[0] + fmla v26.4s, v12.4s, v2.s[0] + fmla v27.4s, v12.4s, v3.s[0] + fmla v28.4s, v12.4s, v4.s[0] + fmla v29.4s, v12.4s, v5.s[0] + fmla v30.4s, v12.4s, v6.s[0] + fmla v31.4s, v12.4s, v7.s[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + add x7, x0, #64 + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x3 + fmla v29.4s, v15.4s, v5.s[3] + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], x3 + fmla v30.4s, v15.4s, v6.s[3] + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], x3 + mov x2, x6 + fmla v31.4s, v15.4s, v7.s[3] + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, 
v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + +LoopOcEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S new file mode 100644 index 0000000000000000000000000000000000000000..f79abfc5c1f5cca0b4e19878d9121546337e7779 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransLeft +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6:length + +sub sp, sp, #32 +stp x19, x20, [sp] + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + ld1 {v0.s}[3], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x20], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x20], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4s}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 
{v1.4s}, [x14], #16 + fmla v0.4s, v1.4s, v31.4s + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #4 //sizeof(float) + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #32 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S new file mode 100644 index 0000000000000000000000000000000000000000..29907d190fa1c43d9da289e1fc7640cb05b8ef8a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S @@ -0,0 +1,160 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransRight +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + ld1 {v0.s}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x19, x16, x8 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x19 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x19 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4s}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + 
ld1 {v1.4s}, [x17], #16 + fmla v0.4s, v1.4s, v31.4s + + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + subs x12, x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #4 //sizeof(float) + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S new file mode 100644 index 0000000000000000000000000000000000000000..e0121e9352c05a1dfc1bc072bd220cb0e0dcf9de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + .text + .align 5 + .global Float16ToFloat32 +#ifndef __APPLE__ + .type Float16ToFloat32, %function +#endif + +// void Float16ToFloat32(const float16_t *input, float *output, int number); +// r0: input, r1: output, r2: number +Float16ToFloat32: + cmp r2, #0 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop16: + vld1.16 {q0, q1}, [r0]! + vcvt.f32.f16 q3, d0 + vcvt.f32.f16 q4, d1 + vcvt.f32.f16 q5, d2 + vst1.32 {q3, q4}, [r1]! + vcvt.f32.f16 q6, d3 + subs r2, r2, #16 + vst1.32 {q5, q6}, [r1]! + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop8: + vld1.16 {q0}, [r0]! + vcvt.f32.f16 q1, d0 + vcvt.f32.f16 q2, d1 + vst1.32 {q1, q2}, [r1]! + subs r2, r2, #8 + beq LoopEnd + cmp r2, #8 + bge Loop8 + b Loop + Loop: + vldr.16 s0, [r0] + vcvtb.f32.f16 s0, s0 + vstr.32 s0, [r1] + add r0, r0, #2 + add r1, r1, #4 + subs r2, r2, #1 + bgt Loop + LoopEnd: + mov pc, lr +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S new file mode 100644 index 0000000000000000000000000000000000000000..85ac9d7bb825c119c3aeb47596a0207348881c35 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + .text + .align 5 + .global Float32ToFloat16 +#ifndef __APPLE__ + .type Float32ToFloat16, %function +#endif + +// void Float32ToFloat16(const float *input, float16_t *output, int number); +// r0: input, r1: output, r2: number +Float32ToFloat16: + cmp r2, #0 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop16: + vld1.32 {q0, q1}, [r0]! + vcvt.f16.f32 d0, q0 + vcvt.f16.f32 d1, q1 + vld1.32 {q2, q3}, [r0]! + vcvt.f16.f32 d2, q2 + vcvt.f16.f32 d3, q3 + vst1.16 {q0, q1}, [r1]! + subs r2, r2, #16 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop8: + vld1.32 {q0, q1}, [r0]! + vcvt.f16.f32 d0, q0 + vcvt.f16.f32 d1, q1 + vst1.16 {q0}, [r1]! + subs r2, r2, #8 + beq LoopEnd + cmp r2, #8 + bge Loop8 + b Loop + Loop: + vldr s0, [r0] + vcvtb.f16.f32 s0, s0 + vstr.16 s0, [r1] + add r0, r0, #4 + add r1, r1, #2 + subs r2, r2, #1 + bgt Loop + LoopEnd: + mov pc, lr +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..1fed588a8402179dfb831327eec4b5e8d2979dcd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int col) { +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: col + +asm_function MatVecMulA32NeonFp16 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r9, r10, r11, lr} + add sp, sp, #52 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + add r10, r5, r5 // stride = depth * sizeof(float16_t) + mov lr, #4 + mul r11, r10, lr // stride x 4 + + cmp r6, #4 + blt Col1Loop + +Col4Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q15, q15, q15 + + cmp r8, #8 + bge Col4Depth8 + cmp r8, #4 + bge Col4Depth4 + cmp r8, #1 + bge Col4Depth1 + b Col4End + + Col4Depth8: + vld1.16 {q8}, [r7]! + add lr, r9, r10 + vld1.16 {q0}, [r9]! + vld1.16 {q1}, [lr], r10 + vld1.16 {q2}, [lr], r10 + vld1.16 {q3}, [lr] + + vmla.f16 q9, q8, q0 + vmla.f16 q10, q8, q1 + vmla.f16 q11, q8, q2 + vmla.f16 q12, q8, q3 + sub r8, r8, #8 + cmp r8, #8 + bge Col4Depth8 + cmp r8, #4 + bge Col4Depth4 + b AddC4 + + Col4Depth4: + vld1.16 {d16}, [r7]! + add lr, r9, r10 + vld1.16 {d0}, [r9]! 
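+ // d16: 4-deep fp16 slice of the vector a; d0/d2/d4/d6: the matching slices of + // matrix rows 0-3 (r10 = depth * 2 bytes apart); column sums accumulate in d18/d20/d22/d24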
+ vld1.16 {d2}, [lr], r10 + vld1.16 {d4}, [lr], r10 + vld1.16 {d6}, [lr] + + vmla.f16 d18, d16, d0 + vmla.f16 d20, d16, d2 + vmla.f16 d22, d16, d4 + vmla.f16 d24, d16, d6 + sub r8, r8, #4 + cmp r8, #4 + bge Col4Depth4 + + AddC4: + vpadd.f16 d0, d18, d19 + vpadd.f16 d1, d20, d21 + vpadd.f16 d2, d22, d23 + vpadd.f16 d4, d24, d25 + vpadd.f16 d30, d0, d1 + vpadd.f16 d31, d2, d4 + vpadd.f16 d30, d30, d31 + cmp r8, #1 + bge Col4Depth1 + b Col4End + + Col4Depth1: + vld1.16 {d0[0]}, [r7]! + add lr, r9, r10 + vld1.16 {d2[0]}, [r9]! + vld1.16 {d2[1]}, [lr], r10 + vld1.16 {d2[2]}, [lr], r10 + vld1.16 {d2[3]}, [lr] + + vmla.f16 d30, d2, d0[0] + subs r8, r8, #1 + bne Col4Depth1 + + Col4End: + cmp r3, #0 + beq Col4Activation + vld1.16 {d26}, [r3]! + vadd.f16 d30, d30, d26 + + Col4Activation: + cmp r4, #3 + beq Col4Relu6 + cmp r4, #1 + beq Col4Relu + b Col4Write + + Col4Relu6: + vmov.i16 q12, #6 + vcvt.f16.s16 q12, q12 + vmin.f16 d30, d30, d24 + + Col4Relu: + veor q13, q13, q13 + vmax.f16 d30, d30, d26 + + Col4Write: + vst1.16 {d30}, [r2]! + subs r6, r6, #4 + beq End + add r1, r1, r11 + cmp r6, #4 + bge Col4Loop + +Col1Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + veor q10, q10, q10 + veor q15, q15, q15 + + cmp r8, #8 + bge Col1Depth8 + cmp r8, #4 + bge Col1Depth4 + cmp r8, #1 + bge Col1Depth1 + b Col1End + + Col1Depth8: + vld1.16 {q0}, [r7]! + vld1.16 {q1}, [r9]! + vmla.f16 q10, q1, q0 + sub r8, r8, #8 + cmp r8, #8 + bge Col1Depth8 + cmp r8, #4 + bge Col1Depth4 + b AddC1 + + Col1Depth4: + vld1.16 {d0}, [r7]! + vld1.16 {d2}, [r9]! + vmla.f16 d20, d2, d0 + sub r8, r8, #4 + cmp r8, #4 + bge Col1Depth4 + + AddC1: + vpadd.f16 d30, d20, d21 + vpadd.f16 d30, d30, d20 + vpadd.f16 d30, d30, d20 + cmp r8, #1 + bge Col1Depth1 + b Col1End + + Col1Depth1: + vld1.16 {d0[0]}, [r7]! + vld1.16 {d2[0]}, [r9]! + vmla.f16 d30, d2, d0[0] + subs r8, r8, #1 + bne Col1Depth1 + + Col1End: + cmp r3, #0 + beq Col1Activation + vld1.16 {d28[0]}, [r3]! + vadd.f16 d30, d30, d28 + + Col1Activation: + cmp r4, #3 + beq Col1Relu6 + cmp r4, #1 + beq Col1Relu + b Col1Write + + Col1Relu6: + vmov.i16 d26, #6 + vcvt.f16.s16 d26, d26 + vmin.f16 d30, d30, d26 + + Col1Relu: + veor d24, d24, d24 + vmax.f16 d30, d30, d24 + + Col1Write: + vst1.16 {d30[0]}, [r2]! + subs r6, r6, #1 + beq End + add r1, r1, r10 + b Col1Loop + +End: + sub sp, sp, #52 + pop {r0-r8, r9, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..781b8c3b4070be800ced1bbf3e93a6310818ae1c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S @@ -0,0 +1,617 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + .text + .align 5 + .global MatMul12x8A32Fp16 +#ifndef __APPLE__ + .type MatMul12x8A32Fp16, %function +#endif + +// void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, +// int deep, int row, int col, int stride, bool write_mode); +// r0: a +// r1: b +// r2: dst +// r3: bias +// #4: depth +// #8: row +// #12: col +// #16: stride +// #20: writeNhwc/writeWino + +asm_function MatMul12x8A32Fp16 + // r13(sp) and r15(pc) can not be used!! + // r9 r4 is tmp register + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r3-r11, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + ldr lr, [sp, #20] + + mov r10, r1 // b + mov r11, r0 // a + mov r12, r2 // dst + + cmp lr, #2 + bne NoWinograd + mul r4, r8, r7 // stride * col + add r4, r4, r4 // r4 * sizeof(float16_t) + mov r9, #16 + mul r9, r8, r9 // stride * 8 * sizeof(float16_t) +NoWinograd: + add r8, r8, r8 // stride * sizeof(float16_t) + +a .req r0 +weight .req r1 +dst .req r2 +bias .req r3 +depth .req r5 +row .req r6 +col .req r7 +stride .req r8 +b_tmp .req r10 +a_tmp .req r11 +dst_tmp .req r12 + +.macro STORE_12x8 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x7 p1, p2, p3 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride +.endm + +.macro STORE_12x6 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x5 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x4 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x3 p1, p2 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x2 p1 + vst1.32 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x1 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_C8 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C7 p1, p2, p3, p4 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride + cmp row, \p4 + beq WriteEnd +.endm + +.macro STORE_C6 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C5 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C4 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C3 p1, p2, p3 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C2 p1, p2 + vst1.32 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +.macro STORE_C1 p1, p2 + vst1.16 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +LoopRow12: + ldr bias, [sp, #-40] + LoopCol8: + mov dst, dst_tmp + mov a, a_tmp + ldr depth, [sp, #4] + veor q4, q4, q4 + veor q5, q5, q5 + veor q6, q6, q6 + veor q7, q7, q7 + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 
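+ // q4-q15 hold the 12x8 fp16 output tile, one row per q register; all are zeroed before the depth loop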
+ veor q15, q15, q15 + LoopDepth: + vld1.16 {q0, d2}, [a]! + vld1.16 {q2}, [weight]! + vmla.f16 q4, q2, d0[0] + vmla.f16 q5, q2, d0[1] + vmla.f16 q6, q2, d0[2] + vmla.f16 q7, q2, d0[3] + vmla.f16 q8, q2, d1[0] + vmla.f16 q9, q2, d1[1] + vmla.f16 q10, q2, d1[2] + vmla.f16 q11, q2, d1[3] + vmla.f16 q12, q2, d2[0] + vmla.f16 q13, q2, d2[1] + vmla.f16 q14, q2, d2[2] + vmla.f16 q15, q2, d2[3] + + subs depth, depth, #1 + bne LoopDepth + + Bias: + cmp bias, #0 + beq Activation + vld1.16 {q0}, [bias]! + vadd.f16 q4, q4, q0 + vadd.f16 q5, q5, q0 + vadd.f16 q6, q6, q0 + vadd.f16 q7, q7, q0 + vadd.f16 q8, q8, q0 + vadd.f16 q9, q9, q0 + vadd.f16 q10, q10, q0 + vadd.f16 q11, q11, q0 + vadd.f16 q12, q12, q0 + vadd.f16 q13, q13, q0 + vadd.f16 q14, q14, q0 + vadd.f16 q15, q15, q0 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i16 q2, #0x4600 + vmin.f16 q4, q4, q2 + vmin.f16 q5, q5, q2 + vmin.f16 q6, q6, q2 + vmin.f16 q7, q7, q2 + vmin.f16 q8, q8, q2 + vmin.f16 q9, q9, q2 + vmin.f16 q10, q10, q2 + vmin.f16 q11, q11, q2 + vmin.f16 q12, q12, q2 + vmin.f16 q13, q13, q2 + vmin.f16 q14, q14, q2 + vmin.f16 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f16 q4, q4, q3 + vmax.f16 q5, q5, q3 + vmax.f16 q6, q6, q3 + vmax.f16 q7, q7, q3 + vmax.f16 q8, q8, q3 + vmax.f16 q9, q9, q3 + vmax.f16 q10, q10, q3 + vmax.f16 q11, q11, q3 + vmax.f16 q12, q12, q3 + vmax.f16 q13, q13, q3 + vmax.f16 q14, q14, q3 + vmax.f16 q15, q15, q3 + + Write: + ldr lr, [sp, #20] + cmp lr, #2 + beq WriteWinograd + cmp row, #12 + bge Write12xCol + b WriteRowxCol + + WriteWinograd: + vst1.16 {q4}, [dst] + add dst, dst, r4 + vst1.16 {q5}, [dst] + add dst, dst, r4 + vst1.16 {q6}, [dst] + add dst, dst, r4 + vst1.16 {q7}, [dst] + add dst, dst, r4 + vst1.16 {q8}, [dst] + add dst, dst, r4 + vst1.16 {q9}, [dst] + add dst, dst, r4 + vst1.16 {q10}, [dst] + add dst, dst, r4 + vst1.16 {q11}, [dst] + add dst, dst, r4 + vst1.16 {q12}, [dst] + add dst, dst, r4 + vst1.16 {q13}, [dst] + add dst, dst, r4 + vst1.16 {q14}, [dst] + add dst, dst, r4 + vst1.16 {q15}, [dst] + add dst_tmp, dst_tmp, r9 + b WriteEnd + Write12xCol: + cmp col, #8 + bge Write12x8 + cmp col, #1 + beq Write12x1 + cmp col, #2 + beq Write12x2 + cmp col, #3 + beq Write12x3 + cmp col, #4 + beq Write12x4 + cmp col, #5 + beq Write12x5 + cmp col, #6 + beq Write12x6 + b Write12x7 + + WriteRowxCol: + cmp col, #8 + bge WriteRowx8 + cmp col, #1 + beq WriteRowx1 + cmp col, #2 + beq WriteRowx2 + cmp col, #3 + beq WriteRowx3 + cmp col, #4 + beq WriteRowx4 + cmp col, #5 + beq WriteRowx5 + cmp col, #6 + beq WriteRowx6 + b WriteRowx7 + + Write12x8: + STORE_12x8 q4 + STORE_12x8 q5 + STORE_12x8 q6 + STORE_12x8 q7 + STORE_12x8 q8 + STORE_12x8 q9 + STORE_12x8 q10 + STORE_12x8 q11 + STORE_12x8 q12 + STORE_12x8 q13 + STORE_12x8 q14 + STORE_12x8 q15 + b WriteEnd + WriteRowx8: + STORE_C8 q4, #1 + STORE_C8 q5, #2 + STORE_C8 q6, #3 + STORE_C8 q7, #4 + STORE_C8 q8, #5 + STORE_C8 q9, #6 + STORE_C8 q10, #7 + STORE_C8 q11, #8 + STORE_C8 q12, #9 + STORE_C8 q13, #10 + STORE_C8 q14, #11 + STORE_C8 q15, #12 + b WriteEnd + + Write12x1: + STORE_12x1 d8[0] + STORE_12x1 d10[0] + STORE_12x1 d12[0] + STORE_12x1 d14[0] + STORE_12x1 d16[0] + STORE_12x1 d18[0] + STORE_12x1 d20[0] + STORE_12x1 d22[0] + STORE_12x1 d24[0] + STORE_12x1 d26[0] + STORE_12x1 d28[0] + STORE_12x1 d30[0] + b WriteEnd + WriteRowx1: + STORE_C1 d8[0], #1 + STORE_C1 d10[0], #2 + STORE_C1 d12[0], #3 + STORE_C1 d14[0], #4 + STORE_C1 d16[0], #5 + STORE_C1 d18[0], #6 + STORE_C1 d20[0], #7 + STORE_C1 d22[0], #8
+ STORE_C1 d24[0], #9 + STORE_C1 d26[0], #10 + STORE_C1 d28[0], #11 + STORE_C1 d30[0], #12 + b WriteEnd + + Write12x2: + STORE_12x2 d8[0] + STORE_12x2 d10[0] + STORE_12x2 d12[0] + STORE_12x2 d14[0] + STORE_12x2 d16[0] + STORE_12x2 d18[0] + STORE_12x2 d20[0] + STORE_12x2 d22[0] + STORE_12x2 d24[0] + STORE_12x2 d26[0] + STORE_12x2 d28[0] + STORE_12x2 d30[0] + b WriteEnd + WriteRowx2: + STORE_C2 d8[0], #1 + STORE_C2 d10[0], #2 + STORE_C2 d12[0], #3 + STORE_C2 d14[0], #4 + STORE_C2 d16[0], #5 + STORE_C2 d18[0], #6 + STORE_C2 d20[0], #7 + STORE_C2 d22[0], #8 + STORE_C2 d24[0], #9 + STORE_C2 d26[0], #10 + STORE_C2 d28[0], #11 + STORE_C2 d30[0], #12 + b WriteEnd + + Write12x3: + STORE_12x3 d8[0], d8[2] + STORE_12x3 d10[0], d10[2] + STORE_12x3 d12[0], d12[2] + STORE_12x3 d14[0], d14[2] + STORE_12x3 d16[0], d16[2] + STORE_12x3 d18[0], d18[2] + STORE_12x3 d20[0], d20[2] + STORE_12x3 d22[0], d22[2] + STORE_12x3 d24[0], d24[2] + STORE_12x3 d26[0], d26[2] + STORE_12x3 d28[0], d28[2] + STORE_12x3 d30[0], d30[2] + b WriteEnd + WriteRowx3: + STORE_C3 d8[0], d8[2], #1 + STORE_C3 d10[0], d10[2], #2 + STORE_C3 d12[0], d12[2], #3 + STORE_C3 d14[0], d14[2], #4 + STORE_C3 d16[0], d16[2], #5 + STORE_C3 d18[0], d18[2], #6 + STORE_C3 d20[0], d20[2], #7 + STORE_C3 d22[0], d22[2], #8 + STORE_C3 d24[0], d24[2], #9 + STORE_C3 d26[0], d26[2], #10 + STORE_C3 d28[0], d28[2], #11 + STORE_C3 d30[0], d30[2], #12 + b WriteEnd + + Write12x4: + STORE_12x4 d8 + STORE_12x4 d10 + STORE_12x4 d12 + STORE_12x4 d14 + STORE_12x4 d16 + STORE_12x4 d18 + STORE_12x4 d20 + STORE_12x4 d22 + STORE_12x4 d24 + STORE_12x4 d26 + STORE_12x4 d28 + STORE_12x4 d30 + b WriteEnd + WriteRowx4: + STORE_C4 d8, #1 + STORE_C4 d10, #2 + STORE_C4 d12, #3 + STORE_C4 d14, #4 + STORE_C4 d16, #5 + STORE_C4 d18, #6 + STORE_C4 d20, #7 + STORE_C4 d22, #8 + STORE_C4 d24, #9 + STORE_C4 d26, #10 + STORE_C4 d28, #11 + STORE_C4 d30, #12 + b WriteEnd + + Write12x5: + STORE_12x5 d8, d9[0] + STORE_12x5 d10, d11[0] + STORE_12x5 d12, d13[0] + STORE_12x5 d14, d15[0] + STORE_12x5 d16, d17[0] + STORE_12x5 d18, d19[0] + STORE_12x5 d20, d21[0] + STORE_12x5 d22, d23[0] + STORE_12x5 d24, d25[0] + STORE_12x5 d26, d27[0] + STORE_12x5 d28, d29[0] + STORE_12x5 d30, d31[0] + b WriteEnd + WriteRowx5: + STORE_C5 d8, d9[0], #1 + STORE_C5 d10, d11[0], #2 + STORE_C5 d12, d13[0], #3 + STORE_C5 d14, d15[0], #4 + STORE_C5 d16, d17[0], #5 + STORE_C5 d18, d19[0], #6 + STORE_C5 d20, d21[0], #7 + STORE_C5 d22, d23[0], #8 + STORE_C5 d24, d25[0], #9 + STORE_C5 d26, d27[0], #10 + STORE_C5 d28, d29[0], #11 + STORE_C5 d30, d31[0], #12 + b WriteEnd + + Write12x6: + STORE_12x6 d8, d9[0] + STORE_12x6 d10, d11[0] + STORE_12x6 d12, d13[0] + STORE_12x6 d14, d15[0] + STORE_12x6 d16, d17[0] + STORE_12x6 d18, d19[0] + STORE_12x6 d20, d21[0] + STORE_12x6 d22, d23[0] + STORE_12x6 d24, d25[0] + STORE_12x6 d26, d27[0] + STORE_12x6 d28, d29[0] + STORE_12x6 d30, d31[0] + b WriteEnd + WriteRowx6: + STORE_C6 d8, d9[0], #1 + STORE_C6 d10, d11[0], #2 + STORE_C6 d12, d13[0], #3 + STORE_C6 d14, d15[0], #4 + STORE_C6 d16, d17[0], #5 + STORE_C6 d18, d19[0], #6 + STORE_C6 d20, d21[0], #7 + STORE_C6 d22, d23[0], #8 + STORE_C6 d24, d25[0], #9 + STORE_C6 d26, d27[0], #10 + STORE_C6 d28, d29[0], #11 + STORE_C6 d30, d31[0], #12 + b WriteEnd + + Write12x7: + STORE_12x7 d8, d9[0], d9[2] + STORE_12x7 d10, d11[0], d11[2] + STORE_12x7 d12, d13[0], d13[2] + STORE_12x7 d14, d15[0], d15[2] + STORE_12x7 d16, d17[0], d17[2] + STORE_12x7 d18, d19[0], d19[2] + STORE_12x7 d20, d21[0], d21[2] + STORE_12x7 d22, d23[0], d23[2] + STORE_12x7 d24, 
d25[0], d25[2] + STORE_12x7 d26, d27[0], d27[2] + STORE_12x7 d28, d29[0], d29[2] + STORE_12x7 d30, d31[0], d31[2] + b WriteEnd + WriteRowx7: + STORE_C7 d8, d9[0], d9[2], #1 + STORE_C7 d10, d11[0], d11[2], #2 + STORE_C7 d12, d13[0], d13[2], #3 + STORE_C7 d14, d15[0], d15[2], #4 + STORE_C7 d16, d17[0], d17[2], #5 + STORE_C7 d18, d19[0], d19[2], #6 + STORE_C7 d20, d21[0], d21[2], #7 + STORE_C7 d22, d23[0], d23[2], #8 + STORE_C7 d24, d25[0], d25[2], #9 + STORE_C7 d26, d27[0], d27[2], #10 + STORE_C7 d28, d29[0], d29[2], #11 + STORE_C7 d30, d31[0], d31[2], #12 + b WriteEnd + + WriteEnd: + cmp col, #8 + ble LoopColEnd + sub col, col, #8 + ldr lr, [sp, #20] + cmp lr, #2 + beq LoopCol8 + add dst_tmp, dst_tmp, #16 + b LoopCol8 + LoopColEnd: + cmp row, #12 + ble LoopRowEnd + sub row, row, #12 + mov a_tmp, a + mov weight, b_tmp + ldr lr, [sp, #20] + cmp lr, #2 + beq WinogradDst + ldr lr, [sp, #12] + sub lr, lr, col + add lr, lr, lr // col *= 2 + sub dst_tmp, dst, lr + b LoopRow + WinogradDst: + add dst_tmp, dst, r9 + LoopRow: + mov dst, dst_tmp + ldr col, [sp, #12] + b LoopRow12 +LoopRowEnd: + sub sp, sp, #104 + vpop {q4-q7} + pop {r3-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..fa32c368f1edf74e7780ba9c12197a1154dde620 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S @@ -0,0 +1,108 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp16 +// void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, +// size_t oc4); +// r0: dst +// r1: src +// r2: weight +// r3: cal_num +// r4(sp): ic4 +// r5(sp + #4): oc4 +push {r4-r11, lr} +vpush {q4-q7} +add sp, sp, #100 +ldr r4, [sp] +ldr r5, [sp, #4] // oc4 +add r3, r3, r3 +mov r7, r1 + +cmp r5, #1 +blt LoopOCEnd +cmp r4, #1 +blt LoopICEnd +LoopOC: + ldr r4, [sp] + veor q15, q15, q15 + veor q14, q14, q14 + veor q13, q13, q13 + veor q12, q12, q12 + LoopIC: + vld1.16 {q4, q5}, [r2]! // weight + vld1.16 {q2, q3}, [r1]! // 16 number src + vmla.f16 d24, d8, d4[0] + vmla.f16 d24, d9, d4[1] + vmla.f16 d24, d10, d4[2] + vmla.f16 d24, d11, d4[3] + + vmla.f16 d25, d8, d5[0] + vmla.f16 d25, d9, d5[1] + vmla.f16 d25, d10, d5[2] + vmla.f16 d25, d11, d5[3] + + vmla.f16 d26, d8, d6[0] + vmla.f16 d26, d9, d6[1] + vmla.f16 d26, d10, d6[2] + vmla.f16 d26, d11, d6[3] + + vmla.f16 d27, d8, d7[0] + vmla.f16 d27, d9, d7[1] + vmla.f16 d27, d10, d7[2] + vmla.f16 d27, d11, d7[3] + + vld1.16 {q0, q1}, [r1]! 
// 16 number src + vmla.f16 d28, d8, d0[0] + vmla.f16 d28, d9, d0[1] + vmla.f16 d28, d10, d0[2] + vmla.f16 d28, d11, d0[3] + + vmla.f16 d29, d8, d1[0] + vmla.f16 d29, d9, d1[1] + vmla.f16 d29, d10, d1[2] + vmla.f16 d29, d11, d1[3] + + vmla.f16 d30, d8, d2[0] + vmla.f16 d30, d9, d2[1] + vmla.f16 d30, d10, d2[2] + vmla.f16 d30, d11, d2[3] + + vmla.f16 d31, d8, d3[0] + vmla.f16 d31, d9, d3[1] + vmla.f16 d31, d10, d3[2] + vmla.f16 d31, d11, d3[3] + + subs r4, r4, #1 + bne LoopIC + b LoopICEnd + LoopICEnd: + mov lr, r0 + vst1.16 {q12, q13}, [lr]! + vst1.16 {q14, q15}, [lr]! + add r0, r0, r3 // dst += cal_num + mov r1, r7 + subs r5, r5, #1 + bne LoopOC +LoopOCEnd: + sub sp, sp, #100 + vpop {q4-q7} + pop {r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S new file mode 100644 index 0000000000000000000000000000000000000000..334ff48e7d7570f350199f223ebaad30d569414a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S @@ -0,0 +1,165 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, +// size_t length); +//r0: S +//r1: B +//r2: M +//r3: w +//r4: h +//r5: k +//r6: length +asm_function WinogradTransLeftFp16 + push {r0, r3, r4-r11, lr} + vpush {q4-q7} + add sp, sp, #108 + ldr r4, [sp] + ldr r6, [sp, #8] + + mov r8, #8 // 4 * sizeof(float16_t) + mul r8, r6, r8 // length * 4 * 2 + mul r7, r3, r8 // step for S + add r10, r4, r4 // step for B + +cmp r4, #1 +blt LoopHEnd +cmp r3, #1 +blt LoopHEnd +LoopH: + ldr r3, [sp, #-40] // w + ldr r0, [sp, #-44] + LoopW: + mov r11, r0 // S + mov lr, r1 // B_src + veor q6, q6, q6 + ldr r6, [sp, #8] + InitZero: + vst1.16 {d12}, [r2]! + subs r6, r6, #1 + bne InitZero + sub r2, r2, r8 + + ldr r5, [sp, #4] + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + cmp r5, #1 + bge LoopK1 + b LoopKEnd + + LoopK4: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + vld1.16 {d7[0]}, [lr], r10 + + add r12, r11, r7 + add r14, r12, r7 + add r9, r14, r7 + LoopK4L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r14]! + vmla.f16 d12, d2, d3[0] + vld1.16 {d6}, [r9]! + vmla.f16 d12, d4, d5[0] + vmla.f16 d12, d6, d7[0] + vst1.16 {d12}, [r2]! 
// dst + subs r6, r6, #1 // length + bne LoopK4L4 + + subs r5, r5, #4 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r9, r9, r8 + add r11, r9, r7 + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK3: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + + add r12, r11, r7 + add r9, r12, r7 + LoopK3L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r9]! + vmla.f16 d12, d2, d3[0] + vmla.f16 d12, d4, d5[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK3L4 + + subs r5, r5, #3 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r9, r9, r8 + add r11, r9, r7 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK1: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + + LoopK1L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vmla.f16 d12, d0, d1[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK1L4 + + subs r5, r5, #1 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r11, r11, r8 + add r11, r11, r7 + b LoopK1 + LoopKEnd: + add r0, r0, r8 // S += unitstep + subs r3, r3, #1 + bne LoopW + LoopWEnd: + subs r4, r4, #1 + beq LoopHEnd + add r1, r1, #2 // B += 1 + b LoopH +LoopHEnd: + sub sp, sp, #108 + vpop {q4-q7} + pop {r0, r3, r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S new file mode 100644 index 0000000000000000000000000000000000000000..cb3a297adb8f9ae36f8a7aea714f36720c7922e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, +// size_t length); +//r0: S +//r1: B +//r2: M +//r3: w +//r4: h +//r5: k +//r6: length +asm_function WinogradTransRightFp16 + push {r1, r3, r4-r11, lr} + vpush {q4-q7} + add sp, sp, #108 + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + mov r8, #8 // 4 * sizeof(float16_t) + mul r8, r6, r8 // length * 4 * 2 + mul r7, r5, r8 // step for S = k * unitStep * 4 + add r10, r4, r4 // step for B = 2 * h + +cmp r4, #1 +blt LoopHEnd +cmp r3, #1 +blt LoopHEnd +LoopH: + ldr r3, [sp, #-40] // w + ldr r1, [sp, #-44] + LoopW: + mov r11, r0 // S + mov lr, r1 // B_src + veor q6, q6, q6 + ldr r6, [sp, #8] + InitZero: + vst1.16 {d12}, [r2]! 
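+ // zero-fill the current length x 4 fp16 unit of M; the k loops below accumulate into it in place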
+ subs r6, r6, #1 + bne InitZero + sub r2, r2, r8 + + ldr r5, [sp, #4] + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + cmp r5, #1 + bge LoopK1 + b LoopKEnd + + LoopK4: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + vld1.16 {d7[0]}, [lr], r10 + + add r12, r11, r8 + add r14, r12, r8 + add r9, r14, r8 + LoopK4L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r14]! + vmla.f16 d12, d2, d3[0] + vld1.16 {d6}, [r9]! + vmla.f16 d12, d4, d5[0] + vmla.f16 d12, d6, d7[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK4L4 + + subs r5, r5, #4 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + mov r11, r9 + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK3: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + + add r12, r11, r8 + add r9, r12, r8 + LoopK3L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r9]! + vmla.f16 d12, d2, d3[0] + vmla.f16 d12, d4, d5[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK3L4 + + subs r5, r5, #3 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + mov r11, r9 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK1: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + + LoopK1L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vmla.f16 d12, d0, d1[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK1L4 + + subs r5, r5, #1 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + b LoopK1 + + LoopKEnd: + add r1, r1, #2 // B[x] + subs r3, r3, #1 + bne LoopW + LoopWEnd: + add r0, r0, r7 + subs r4, r4, #1 + beq LoopHEnd + b LoopH +LoopHEnd: + sub sp, sp, #108 + vpop {q4-q7} + pop {r1, r3, r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S new file mode 100644 index 0000000000000000000000000000000000000000..0e9eac3ccb542b683b849123d70f5890ae46e9ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S @@ -0,0 +1,313 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_AVX
+#include "nnacl_c/assembly_global.h"
+.text
+.align 4
+
+// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels,
+//                       size_t output_width, size_t input_stride, size_t relu, size_t relu6)
+// in linux x64 platform:
+// rdi: output
+// rsi: input
+// rdx: weights
+// rcx: bias
+// r8: channels
+// r9: output_width
+// 8: input_stride
+// 16: relu
+// 24: relu6
+
+// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
+// rcx: output
+// rdx: input
+// r8: weights
+// r9: bias
+// 40: channels
+// 48: output_width
+// 56: input_stride
+// 64: relu
+// 72: relu6
+asm_function ConvDwFp32Avx3x3
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rbx
+    pushq %rbp
+    pushq %r9  // -56
+    pushq %r8  // -64
+    pushq %rcx // -72
+    pushq %rdx // -80
+    pushq %rsi // -88
+    pushq %rdi // -96
+    addq $96, %rsp
+
+#ifdef _WIN32
+    movq %rcx, %rdi
+    movq %rdx, %rsi
+    movq %r8, %rdx
+    movq %r9, %rcx
+    movq 40(%rsp), %r8 // channels
+    movq 48(%rsp), %r9 // output_width
+
+    mov %rdx, -80(%rsp)
+    mov %rcx, -72(%rsp)
+    mov %r9, -56(%rsp)
+    mov %r8, -64(%rsp)
+    movq 56(%rsp), %rbp // input_stride
+    movq %rbp, 8(%rsp)
+    movq 64(%rsp), %rbp // relu
+    movq %rbp, 16(%rsp)
+    movq 72(%rsp), %rbp // relu6
+    movq %rbp, 24(%rsp)
+#endif
+
+    movq $6, %rax
+    vcvtsi2ss %rax, %xmm15, %xmm15
+    vshufps $0, %xmm15, %xmm15, %xmm15
+    vinsertf128 $1, %xmm15, %ymm15, %ymm15
+    vxorps %ymm14, %ymm14, %ymm14
+
+    LoopPixel:
+        movq -80(%rsp), %rdx
+        movq -72(%rsp), %rcx
+        movq -64(%rsp), %r8
+        movq (%rsi), %r9
+        movq 8(%rsi), %r10
+        movq 16(%rsi), %r11
+        movq 24(%rsi), %r12
+        movq 32(%rsi), %r13
+        movq 40(%rsi), %r14
+        movq 48(%rsi), %r15
+        movq 56(%rsi), %rbp
+        movq 64(%rsi), %rbx
+
+        vmovups (%r9), %ymm0
+        addq $32, %r9
+        vmovups (%r10), %ymm1
+        addq $32, %r10
+        vmovups (%r11), %ymm2
+        addq $32, %r11
+
+        vmovups (%rdx), %ymm11
+        addq $32, %rdx
+        vmovups (%rdx), %ymm12
+        addq $32, %rdx
+        vmovups (%rdx), %ymm13
+        addq $32, %rdx
+
+        vmovups (%rcx), %ymm10
+        addq $32, %rcx
+
+        cmpq $8, %r8
+        jbe LeftLoop
+        LoopC8:
+            vfmadd231ps %ymm11, %ymm0, %ymm10
+            vmovups (%r12), %ymm3
+            addq $32, %r12
+            vmovups (%rdx), %ymm11
+            addq $32, %rdx
+            vfmadd231ps %ymm12, %ymm1, %ymm10
+            vmovups (%r13), %ymm4
+            addq $32, %r13
+            vmovups (%rdx), %ymm12
+            addq $32, %rdx
+            vfmadd231ps %ymm13, %ymm2, %ymm10
+            vmovups (%r14), %ymm5
+            addq $32, %r14
+            vmovups (%rdx), %ymm13
+            addq $32, %rdx
+            vfmadd231ps %ymm11, %ymm3, %ymm10
+            vmovups (%r15), %ymm6
+            addq $32, %r15
+            vmovups (%rdx), %ymm11
+            addq $32, %rdx
+            vfmadd231ps %ymm12, %ymm4, %ymm10
+            vmovups (%rbp), %ymm7
+            addq $32, %rbp
+            vmovups (%rdx), %ymm12
+            addq $32, %rdx
+            vfmadd231ps %ymm13, %ymm5, %ymm10
+            vmovups (%rbx), %ymm8
+            addq $32, %rbx
+            vmovups (%rdx), %ymm13
+            addq $32, %rdx
+            vfmadd231ps %ymm11, %ymm6, %ymm10
+            vmovups (%r9), %ymm0
+            addq $32, %r9
+            vmovups (%rdx), %ymm11
+            addq $32, %rdx
+            vfmadd231ps %ymm12, %ymm7, %ymm10
+            vmovups (%r10), %ymm1
+            addq $32, %r10
+            vmovups (%rdx), %ymm12
+            addq $32, %rdx
+            vfmadd231ps %ymm13, %ymm8, %ymm10
+            vmovups (%r11), %ymm2
+            addq $32, %r11
+            vmovups (%rdx), %ymm13
+            addq $32, %rdx
+
+            movq 24(%rsp), %rax
+            cmpq $0, %rax
+            jne Relu6
+            movq 16(%rsp), %rax
+            cmpq $0, %rax
+            jne Relu
+            jmp Write
+            Relu6:
+                vminps %ymm15, %ymm10, %ymm10
+            Relu:
+                vmaxps %ymm14, %ymm10, %ymm10
+            Write:
+                vmovups %ymm10, (%rdi)
+                addq $32, %rdi
+
+            vmovups (%rcx), %ymm10
+            addq $32, %rcx
+            subq $8, %r8
+            cmpq $8, %r8
+            ja LoopC8
+
+        LeftLoop:
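+            // note: LeftLoop is the remainder path for the last 1..8 channels.
+            // It repeats the same 3x3 fused multiply-accumulate pattern as
+            // LoopC8, then dispatches to Write1..Write8 so only the valid
+            // lanes of ymm10 reach memory.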
vfmadd231ps %ymm11, %ymm0, %ymm10 + vmovups (%r12), %ymm3 + addq $32, %r12 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm1, %ymm10 + vmovups (%r13), %ymm4 + addq $32, %r13 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm2, %ymm10 + vmovups (%r14), %ymm5 + addq $32, %r14 + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm3, %ymm10 + vmovups (%r15), %ymm6 + addq $32, %r15 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm4, %ymm10 + vmovups (%rbp), %ymm7 + addq $32, %rbp + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm5, %ymm10 + vmovups (%rbx), %ymm8 + addq $32, %rbx + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm6, %ymm10 + vfmadd231ps %ymm12, %ymm7, %ymm10 + vfmadd231ps %ymm13, %ymm8, %ymm10 + + movq 24(%rsp), %rax + cmpq $0, %rax + jne LeftRelu6 + movq 16(%rsp), %rax + cmpq $0, %rax + jne LeftRelu + jmp LeftWrite + LeftRelu6: + vminps %ymm15, %ymm10, %ymm10 + LeftRelu: + vmaxps %ymm14, %ymm10, %ymm10 + LeftWrite: + cmpq $1, %r8 + je Write1 + cmpq $2, %r8 + je Write2 + cmpq $3, %r8 + je Write3 + cmpq $4, %r8 + je Write4 + cmpq $5, %r8 + je Write5 + cmpq $6, %r8 + je Write6 + cmpq $7, %r8 + je Write7 + jmp Write8 + Write1: + vmovss %xmm10, (%rdi) + addq $4, %rdi + jmp NextPixel + Write2: + vmovsd %xmm10, (%rdi) + addq $8, %rdi + jmp NextPixel + Write3: + vmovsd %xmm10, (%rdi) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdi) + addq $12, %rdi + jmp NextPixel + Write4: + vmovups %xmm10, (%rdi) + addq $16, %rdi + jmp NextPixel + Write5: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovss %xmm9, 16(%rdi) + addq $20, %rdi + jmp NextPixel + Write6: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovsd %xmm9, 16(%rdi) + addq $24, %rdi + jmp NextPixel + Write7: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovsd %xmm9, 16(%rdi) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 24(%rdi) + addq $28, %rdi + jmp NextPixel + Write8: + vmovups %ymm10, (%rdi) + add $32, %rdi + + NextPixel: + movq 8(%rsp), %rbp + addq %rbp, %rsi + movq -56(%rsp), %rax + subq $1, %rax + movq %rax, -56(%rsp) + cmpq $0, %rax + ja LoopPixel +End: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S new file mode 100644 index 0000000000000000000000000000000000000000..8e6c938d945d8b9d3cc0b9eec584c5c70f46cdab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S @@ -0,0 +1,188 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_AVX
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 4
+
+// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height,
+//                       size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
+//                       size_t relu6);
+
+asm_function ConvDwFp32Border
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rbx
+    pushq %rbp
+    pushq %r9
+    pushq %r8  // -64
+    pushq %rcx // -72
+    pushq %rdx // -80
+    pushq %rsi
+    pushq %rdi
+    addq $96, %rsp
+
+    movq %rdi, %rdx
+#ifdef _WIN32
+    movq %rcx, %rdx
+#endif
+    movq 8(%rdx), %r12  // src
+    movq 16(%rdx), %r13 // weight
+    movq 24(%rdx), %rbp // bias
+    movq 32(%rdx), %r11 // height
+    movq 40(%rdx), %r10
+    movq %r10, -72(%rsp) // width
+    movq 48(%rdx), %r10
+    movq %r10, -80(%rsp) // in_kh_step
+    movq 56(%rdx), %r10  // in_kw_step
+    movq 64(%rdx), %rax  // kernel_w
+    movq 72(%rdx), %rcx  // relu
+    movq 80(%rdx), %rbx  // relu6
+    movl $0x40C00000, -64(%rsp) // bit pattern of 6.0f for the Relu6 clamp
+    movq (%rdx), %rdx
+    cmpq $0, %r11
+    je End
+
+    xorps %xmm8, %xmm8
+    LoopHeight:
+        movq %r12, %rsi // src_kh, src_kw
+        movq %r13, %rdi // weight_kh, weight_kw
+        movq -72(%rsp), %r8 // width
+
+        cmpq $6, %r8
+        jae LoopWidth6
+        cmpq $4, %r8
+        jae LoopWidth4
+        cmpq $1, %r8
+        jae LoopWidth1
+        jmp LoopWidthEnd
+
+        LoopWidth6:
+            xorps %xmm6, %xmm6
+            xorps %xmm7, %xmm7
+            imul $3, %r10, %r9
+            addq %rsi, %r9
+            vmovups (%rsi), %xmm0 // src_kw
+            vmovups (%rsi, %r10), %xmm1
+            vmovups (%rsi, %r10, 2), %xmm2
+            vmovups (%r9), %xmm3
+            vmovups (%rsi, %r10, 4), %xmm4
+            vmovups (%r9, %r10, 2), %xmm5
+
+            vfmadd231ps (%rdi), %xmm0, %xmm6
+            vfmadd231ps 16(%rdi), %xmm1, %xmm7
+            vfmadd231ps 32(%rdi), %xmm2, %xmm8
+            vfmadd231ps 48(%rdi), %xmm3, %xmm6
+            vfmadd231ps 64(%rdi), %xmm4, %xmm7
+            vfmadd231ps 80(%rdi), %xmm5, %xmm8
+
+            addps %xmm6, %xmm7
+            imul $6, %r10, %r15
+            addq $96, %rdi
+            addps %xmm7, %xmm8
+            addq %r15, %rsi
+
+            subq $6, %r8
+            cmpq $6, %r8
+            jae LoopWidth6
+            cmpq $4, %r8
+            jae LoopWidth4
+            cmpq $0, %r8
+            je LoopWidthEnd
+            jmp LoopWidth1
+
+        LoopWidth4:
+            xorps %xmm6, %xmm6
+            xorps %xmm7, %xmm7
+            imul $3, %r10, %r9
+            addq %rsi, %r9
+            vmovups (%rsi), %xmm0 // src_kw
+            vmovups (%rsi, %r10, 1), %xmm1
+            vmovups (%rsi, %r10, 2), %xmm2
+            vmovups (%r9), %xmm3
+
+            vfmadd231ps (%rdi), %xmm0, %xmm6
+            vfmadd231ps 16(%rdi), %xmm1, %xmm7
+            vfmadd231ps 32(%rdi), %xmm2, %xmm8
+            vfmadd231ps 48(%rdi), %xmm3, %xmm6
+
+            addps %xmm6, %xmm7
+            imul $4, %r10, %r15
+            addq $64, %rdi
+            addps %xmm7, %xmm8
+            addq %r15, %rsi
+
+            subq $4, %r8
+            cmpq $4, %r8
+            jae LoopWidth4
+            cmpq $0, %r8
+            je LoopWidthEnd
+            jmp LoopWidth1
+
+        LoopWidth1:
+            vmovups (%rsi), %xmm0 // input_tmp
+            addq %r10, %rsi
+            vfmadd231ps (%rdi), %xmm0, %xmm8
+            addq $16, %rdi
+
+            subq $1, %r8
+            cmpq $0, %r8
+            ja LoopWidth1
+            jmp LoopWidthEnd
+
+        LoopWidthEnd:
+            subq $1, %r11
+            cmpq $0, %r11
+            je LoopHeightEnd
+            addq -80(%rsp), %r12 // in_kh_step
+            addq %rax, %r13      // kernel_w_step
+            jmp LoopHeight
+
+    LoopHeightEnd:
+        xorps %xmm10, %xmm10
+        vbroadcastss -64(%rsp), %xmm9
+
+        addps (%rbp), %xmm8
+        cmpq $1, %rbx
+        je Relu6
+        cmpq $1, %rcx
+        je Relu
+        jmp Write
+        Relu6:
+            minps %xmm9, %xmm8
+        Relu:
+            maxps %xmm10, %xmm8
+        Write:
+            movups %xmm8, (%rdx)
+End:
+    subq $96, %rsp
+    popq %rdi
+    popq %rsi
+    popq %rdx
+    popq %rcx
+    popq %r8
+    popq %r9
+    popq %rbp
+    popq %rbx
+    popq %r12
+    popq %r13
+    popq %r14
+    popq %r15
+    retq
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S
new file mode 100644
index 0000000000000000000000000000000000000000..2b936afb971553c77054bf095117ad8c3e771a01
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S
@@ -0,0 +1,189 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_AVX
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 4
+
+// void ConvDwFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels,
+//                    size_t output_channel, size_t input_step);
+// in linux x64 platform:
+// rdi: output_ptr
+// rsi: input_ptr
+// rdx: weight_ptr
+// rcx: num_pixels
+// r8: output_channel
+// r9: input_step
+
+// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
+// rcx: output_ptr
+// rdx: input_ptr
+// r8: weight_ptr
+// r9: num_pixels
+// 40: output_channel
+// 48: input_step
+
+asm_function ConvDwFp32Row
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rsi
+    pushq %rdi
+    addq $48, %rsp
+
+#ifdef _WIN32
+    movq %rcx, %rdi // output_ptr
+    movq %rdx, %rsi // input_ptr
+    movq %r8, %rdx  // weight_ptr
+    movq %r9, %rcx  // num_pixels
+    movq 40(%rsp), %r8 // output_channel
+    movq 48(%rsp), %r9 // input_step
+#endif
+
+    movq $4, %r13
+    imul %r13, %r9
+    movq %rsi, %r13 // input_ptr
+    movq %rdx, %r14 // weight_ptr
+    movq %r8, %r15  // output_channel
+    cmpq $0, %rcx
+    je End
+
+    LoopPixel:
+        movq %r13, %rsi // input_tmp
+        movq %r14, %rdx // weight_tmp
+        movq %r15, %r8  // channel_tmp
+
+        cmpq $32, %r8
+        jae LoopC32
+        cmpq $16, %r8
+        jae LoopC16
+        cmpq $8, %r8
+        jae LoopC8
+        cmpq $0, %r8
+        ja LoopC
+        jmp LoopCEnd
+
+        LoopC32:
+            vmovups (%rsi), %ymm0 // input_tmp
+            vmovups 32(%rsi), %ymm1
+            vmovups 64(%rsi), %ymm2
+            vmovups 96(%rsi), %ymm3
+
+            vmovups (%rdi), %ymm8 // output_tmp
+            vmovups 32(%rdi), %ymm9
+            vmovups 64(%rdi), %ymm10
+            vmovups 96(%rdi), %ymm11
+
+            addq $128, %rsi
+            vfmadd231ps (%rdx), %ymm0, %ymm8
+            vfmadd231ps 32(%rdx), %ymm1, %ymm9
+            vfmadd231ps 64(%rdx), %ymm2, %ymm10
+            vfmadd231ps 96(%rdx), %ymm3, %ymm11
+
+            vmovups %ymm8, (%rdi) // output_ptr
+            vmovups %ymm9, 32(%rdi)
+            vmovups %ymm10, 64(%rdi)
+            vmovups %ymm11, 96(%rdi)
+            addq $128, %rdi
+            addq $128, %rdx
+
+            subq $32, %r8
+            cmpq $32, %r8
+            jae LoopC32
+            cmpq $16, %r8
+            jae LoopC16
+            cmpq $8, %r8
+            jae LoopC8
+            cmpq $0, %r8
+            ja LoopC
+            jmp LoopCEnd
+
+        LoopC16:
+            vmovups (%rsi), %ymm0 // input_tmp
+            vmovups (%rdi), %ymm8 // output_tmp
+            vmovups 32(%rsi), %ymm1
+            vmovups 32(%rdi), %ymm9
+            addq $64, %rsi
+
+            vfmadd231ps (%rdx), %ymm0, %ymm8
+            vfmadd231ps 32(%rdx), %ymm1, %ymm9
+
+            vmovups %ymm8, (%rdi) // output_ptr
+            addq $64, %rdx
+            vmovups %ymm9, 32(%rdi)
+            addq $64, %rdi
+
+            subq $16, %r8
+            cmpq $16, %r8
+            jae LoopC16
+            cmpq $8, %r8
+            jae LoopC8
+            cmpq $0, %r8
+            ja LoopC
+            jmp LoopCEnd
+
+        LoopC8:
+            vmovups (%rsi), %ymm0 // input_tmp
+            vmovups (%rdi), %ymm8 // output_tmp
+            addq $32, %rsi
+
+            vfmadd231ps (%rdx), %ymm0, %ymm8
+
+            addq $32, %rdx
+            vmovups %ymm8, (%rdi)
+            addq $32, %rdi
+
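+            // note: each LoopC8 pass computes output[c] += input[c] * weight[c]
+            // for 8 channels with a single ymm fma; when fewer than 8 channels
+            // remain, control falls through to the scalar LoopC tail below.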
+            subq $8, %r8
+            cmpq $8, %r8
+            jae LoopC8
+            cmpq $0, %r8
+            ja LoopC
+            jmp LoopCEnd
+
+        LoopC:
+            vmovss (%rsi), %xmm0 // input_tmp
+            vmovss (%rdi), %xmm8 // output_ptr
+
+            vfmadd231ss (%rdx), %xmm0, %xmm8
+
+            addq $4, %rsi
+            addq $4, %rdx
+            vmovss %xmm8, (%rdi)
+            addq $4, %rdi
+
+            subq $1, %r8
+            cmpq $0, %r8
+            ja LoopC
+            jmp LoopCEnd
+
+        LoopCEnd:
+            subq $1, %rcx // num_pixel -= 1
+            cmpq $0, %rcx
+            je End
+            addq %r9, %r13
+            jmp LoopPixel
+End:
+    subq $48, %rsp
+    popq %rdi
+    popq %rsi
+    popq %r12
+    popq %r13
+    popq %r14
+    popq %r15
+    retq
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S
new file mode 100644
index 0000000000000000000000000000000000000000..9492bd6a0bcff2a49f8be421789a263254b733e4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S
@@ -0,0 +1,382 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_AVX
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 4
+
+// void ConvDwAVXFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels,
+//                       size_t output_channel, size_t input_step, size_t first_calc_flag, const float *bias);
+// in linux x64 platform:
+// rdi: output_ptr
+// rsi: input_ptr
+// rdx: weight_ptr
+// rcx: num_pixels
+// r8: output_channel
+// r9: input_step
+// 8: first_calc_flag
+// 16: bias
+
+// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
+// rcx: output_ptr
+// rdx: input_ptr
+// r8: weight_ptr
+// r9: num_pixels
+// 40: output_channel
+// 48: input_step
+// 56: first_calc_flag
+// 64: bias
+
+asm_function ConvDwAVXFp32Row
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rsi
+    pushq %rdi
+    addq $48, %rsp
+
+#ifdef _WIN32
+    movq %rcx, %rdi // output_ptr
+    movq %rdx, %rsi // input_ptr
+    movq %r8, %rdx  // weight_ptr
+    movq %r9, %rcx  // num_pixels
+    movq 40(%rsp), %r8 // output_channel
+    movq 48(%rsp), %r9 // input_step
+    movq 56(%rsp), %r11 // first_calc_flag
+    movq 64(%rsp), %r10 // bias
+#else
+    movq 8(%rsp), %r11  // first_calc_flag
+    movq 16(%rsp), %r10 // bias
+#endif
+
+    movq $4, %r13
+    imul %r13, %r9
+    movq %r8, %r12
+    imul %r13, %r12
+    movq %rsi, %r13 // input_ptr
+    movq %rdx, %r14 // weight_ptr
+    movq %r8, %r15  // output_channel
+
+    cmpq $1, %r11
+    je OutputInitByBias
+    jmp OutputInitBySelf
+
+OutputInitByBias:
+    cmpq $3, %rcx
+    jae BiasLoopPixelNum4
+    cmpq $0, %rcx
+    ja BiasLoopPixel
+    je End
+
+    BiasLoopPixelNum4:
+        movq %r13, %rsi // input_tmp
+        movq %r14, %rdx // weight_tmp
+        movq %r15, %r8  // channel_tmp
+        movq 16(%rsp), %r10 // bias_tmp
+
+        cmpq $8, %r8
+        jae BiasLoopC8Num4
+        cmpq $0, %r8
+        ja BiasLoopCNum4
+        jmp BiasLoopCEndNum4
+
+        BiasLoopC8Num4:
+            vmovups (%rsi), %ymm0 // input_tmp
+            vmovups (%rsi, %r9), %ymm1
+            vmovups (%rsi, %r9, 2), %ymm2
+            // vmovups (%rsi, %r9, 3), %ymm3
+
+            vmovups (%r10), %ymm8  // output_tmp
+            vmovups (%r10), %ymm9  // output_tmp
+            vmovups (%r10), %ymm10 // output_tmp
+            // vmovups
(%r10), %ymm11 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%r10), %xmm8 // output_ptr + vmovss (%r10), %xmm9 // output_tmp + vmovss (%r10), %xmm10 // output_tmp + // vmovss (%r10), %xmm11 // output_tmp + addq $4, %r10 + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%r10), %ymm8 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%r10), %xmm8 // output_ptr + addq $4, %r10 + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp BiasLoopPixel + +OutputInitBySelf: + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups (%rdi, %r12), %ymm9 // output_tmp + vmovups (%rdi, %r12, 2), %ymm10 // output_tmp + // vmovups (%rdi, %r12, 3), %ymm11 // output_tmp + addq $32, %rsi + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCNum4: + vmovss 
(%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%rdi), %xmm8 // output_ptr + vmovss (%rdi, %r12), %xmm9 // output_tmp + vmovss (%rdi, %r12, 2), %xmm10 // output_tmp + // vmovss (%rdi, %r12, 3), %xmm11 // output_tmp + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S new file mode 100644 index 0000000000000000000000000000000000000000..85a64041dc69b56cc82f279b2ae04c987ec8af87 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S @@ -0,0 +1,993 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_AVX
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 4
+
+// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+//                        int row, int col, size_t stride, size_t writeMode)
+// parameters passed in Linux x64 platform:
+// rdi: a
+// rsi: b
+// rdx: c
+// rcx: bias
+// r8: act_type
+// r9: depth
+// 8: row
+// 16: col
+// 24: stride
+// 32: writeNhwc/writeWino
+
+// parameters passed in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
+// rcx: a
+// rdx: b
+// r8: c
+// r9: bias
+// 40: act_type
+// 48: depth
+// 56: row
+// 64: col
+// 72: stride
+// 80: writeMode
+
+asm_function MatmulFloatAvxOpt
+    // rbx, rsp, rbp, r12-r15 must be saved according to the x86-64 calling convention
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rbx
+    pushq %rbp
+    pushq %r9  // -56
+    pushq %r8  // -64
+    pushq %rcx // -72
+    pushq %rdx // -80
+    pushq %rsi // -88
+    pushq %rdi // -96
+    pushq %rsi // -104 rsi
+    pushq %rdi // -112 rdi
+    addq $112, %rsp
+#ifdef _WIN32
+    movq %rcx, %rdi
+    movq %rdx, %rsi
+    movq %r8, %rdx
+    movq %r9, %rcx
+    movq 40(%rsp), %r8 // act_type
+    movq 48(%rsp), %r9 // depth
+    movq %r9, -56(%rsp)  // r9
+    movq %rcx, -72(%rsp) // rcx
+    movq %rdx, -80(%rsp) // rdx
+    movq %rsi, -88(%rsp) // rsi
+    movq %rdi, -96(%rsp) // rdi
+
+    movq 56(%rsp), %rbp // row
+    movq %rbp, 8(%rsp)
+    movq 64(%rsp), %rbp // col
+    movq %rbp, 16(%rsp)
+    movq 72(%rsp), %rbp // stride
+    movq %rbp, 24(%rsp)
+    movq 80(%rsp), %rbp // writeMode
+    movq %rbp, 32(%rsp)
+#endif
+    movq 8(%rsp), %rbp
+    movq 16(%rsp), %rbx
+    movq 24(%rsp), %r10
+    movq 32(%rsp), %r14
+
+    movq $24, %r11
+    imul %r9, %r11
+    cmpq $0, %r14
+    jne NoC8Steps
+    movq $48, %r13
+    imul %rbp, %r13
+NoC8Steps:
+    cmpq $2, %r14
+    jne NoWinoSteps
+    movq $4, %r12
+    imul %r10, %r12
+    imul %rbx, %r12
+    movq $32, %r13
+    imul %r10, %r13
+NoWinoSteps:
+    movq $4, %rax
+    imul %rax, %r10
+
+LoopRow:
+    movq -88(%rsp), %rsi
+    movq 16(%rsp), %rbx
+    movq -72(%rsp), %rcx
+
+    LoopCol:
+        cmpq $0, %r14
+        je NoReloadDst
+        movq -80(%rsp), %rdx
+        NoReloadDst:
+        movq -96(%rsp), %rdi
+        movq -56(%rsp), %r9
+
+        vmovups (%rsi), %ymm0
+        vmovups 32(%rsi), %ymm1
+        vbroadcastss (%rdi), %ymm10
+        vbroadcastss 4(%rdi), %ymm11
+        vbroadcastss 8(%rdi), %ymm12
+        vbroadcastss 12(%rdi), %ymm13
+        vbroadcastss 16(%rdi), %ymm2
+        vbroadcastss 20(%rdi), %ymm3
+        addq $64, %rsi
+        vmulps %ymm0, %ymm10, %ymm4
+        vmulps %ymm1, %ymm10, %ymm5
+        vmulps %ymm0, %ymm11, %ymm6
+        vmulps %ymm1, %ymm11, %ymm7
+        vmulps %ymm0, %ymm12, %ymm8
+        vmulps %ymm1, %ymm12, %ymm9
+        vmulps %ymm0, %ymm13, %ymm10
+        vmulps %ymm1, %ymm13, %ymm11
+        add $24, %rdi
+        vmulps %ymm0, %ymm2, %ymm12
+        vmulps %ymm1, %ymm2, %ymm13
+        vmulps %ymm0, %ymm3, %ymm14
+        vmulps %ymm1, %ymm3, %ymm15
+
+        subq $1, %r9
+        cmpq $0, %r9
+        je Bias
+
+        LoopDepth:
+            vmovups (%rsi), %ymm0
+            vmovups 32(%rsi), %ymm1
+            vbroadcastss (%rdi), %ymm2
+            vbroadcastss 4(%rdi), %ymm3
+            vfmadd231ps %ymm0, %ymm2, %ymm4
+            addq $64, %rsi
+            vfmadd231ps %ymm1, %ymm2, %ymm5
+            vbroadcastss 8(%rdi), %ymm2
+            vfmadd231ps %ymm0, %ymm3, %ymm6
+            vfmadd231ps %ymm1, %ymm3, %ymm7
+            vbroadcastss 12(%rdi), %ymm3
+            vfmadd231ps %ymm0, %ymm2, %ymm8
+            prefetcht0 384(%rsi)
+            vfmadd231ps %ymm1, %ymm2, %ymm9
+            vbroadcastss 16(%rdi), %ymm2
+            vfmadd231ps %ymm0, %ymm3, %ymm10
+            vfmadd231ps %ymm1, %ymm3, %ymm11
+            vbroadcastss 20(%rdi), %ymm3
+            vfmadd231ps %ymm0, %ymm2, %ymm12
+            vfmadd231ps %ymm1, %ymm2, %ymm13
+            addq $24, %rdi
+            vfmadd231ps %ymm0, %ymm3, %ymm14
+            vfmadd231ps %ymm1, %ymm3,
%ymm15 + + subq $1, %r9 + cmpq $0, %r9 + ja LoopDepth + + Bias: + cmpq $0, %rcx + je Activation + vmovups (%rcx), %ymm0 + vmovups 32(%rcx), %ymm1 + add $64, %rcx + vaddps %ymm0, %ymm4, %ymm4 + vaddps %ymm1, %ymm5, %ymm5 + vaddps %ymm0, %ymm6, %ymm6 + vaddps %ymm1, %ymm7, %ymm7 + vaddps %ymm0, %ymm8, %ymm8 + vaddps %ymm1, %ymm9, %ymm9 + vaddps %ymm0, %ymm10, %ymm10 + vaddps %ymm1, %ymm11, %ymm11 + vaddps %ymm0, %ymm12, %ymm12 + vaddps %ymm1, %ymm13, %ymm13 + vaddps %ymm0, %ymm14, %ymm14 + vaddps %ymm1, %ymm15, %ymm15 + + Activation: + cmpq $3, %r8 + je Relu6 + cmpq $1, %r8 + je Relu + jmp Write + + Relu6: + movq $6, %rax + vcvtsi2ss %rax, %xmm0, %xmm0 + vshufps $0, %xmm0, %xmm0, %xmm0 + vinsertf128 $1, %xmm0, %ymm0, %ymm0 + vminps %ymm0, %ymm4, %ymm4 + vminps %ymm0, %ymm5, %ymm5 + vminps %ymm0, %ymm6, %ymm6 + vminps %ymm0, %ymm7, %ymm7 + vminps %ymm0, %ymm8, %ymm8 + vminps %ymm0, %ymm9, %ymm9 + vminps %ymm0, %ymm10, %ymm10 + vminps %ymm0, %ymm11, %ymm11 + vminps %ymm0, %ymm12, %ymm12 + vminps %ymm0, %ymm13, %ymm13 + vminps %ymm0, %ymm14, %ymm14 + vminps %ymm0, %ymm15, %ymm15 + + Relu: + vxorps %ymm1, %ymm1, %ymm1 + vmaxps %ymm1, %ymm4, %ymm4 + vmaxps %ymm1, %ymm5, %ymm5 + vmaxps %ymm1, %ymm6, %ymm6 + vmaxps %ymm1, %ymm7, %ymm7 + vmaxps %ymm1, %ymm8, %ymm8 + vmaxps %ymm1, %ymm9, %ymm9 + vmaxps %ymm1, %ymm10, %ymm10 + vmaxps %ymm1, %ymm11, %ymm11 + vmaxps %ymm1, %ymm12, %ymm12 + vmaxps %ymm1, %ymm13, %ymm13 + vmaxps %ymm1, %ymm14, %ymm14 + vmaxps %ymm1, %ymm15, %ymm15 + + Write: + cmpq $2, %r14 + je WriteWino + cmpq $0, %r14 + je WriteC8 + cmpq $1, %rbx + je Write1 + cmpq $2, %rbx + je Write2 + cmpq $3, %rbx + je Write3 + cmpq $4, %rbx + je Write4 + cmpq $5, %rbx + je Write5 + cmpq $6, %rbx + je Write6 + cmpq $7, %rbx + je Write7 + cmpq $8, %rbx + je Write8 + cmpq $9, %rbx + je Write9 + cmpq $10, %rbx + je Write10 + cmpq $11, %rbx + je Write11 + cmpq $12, %rbx + je Write12 + cmpq $13, %rbx + je Write13 + cmpq $14, %rbx + je Write14 + cmpq $15, %rbx + je Write15 + jmp Write16 + + Write1: + movq %rdx, %rax + addq $4, %rax + movq %rax, -80(%rsp) + vmovss %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm14, (%rdx) + addq %r10, %rdx + addq $4, %rdx + jmp WriteEnd + Write2: + movq %rdx, %rax + addq $8, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + addq %r10, %rdx + addq $8, %rdx + jmp WriteEnd + Write3: + movq %rdx, %rax + addq $12, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 8(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 8(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 8(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + 
vmovsd %xmm12, (%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 8(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 8(%rdx) + addq %r10, %rdx + addq $12, %rdx + jmp WriteEnd + Write4: + movq %rdx, %rax + addq $16, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + addq %r10, %rdx + addq $16, %rdx + jmp WriteEnd + Write5: + movq %rdx, %rax + addq $20, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovss %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovss %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovss %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovss %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovss %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovss %xmm14, 16(%rdx) + addq %r10, %rdx + addq $20, %rdx + jmp WriteEnd + Write6: + movq %rdx, %rax + addq $24, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + addq %r10, %rdx + addq $24, %rdx + jmp WriteEnd + Write7: + movq %rdx, %rax + addq $28, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 24(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 24(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 24(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 24(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 24(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 
24(%rdx) + addq %r10, %rdx + addq $28, %rdx + jmp WriteEnd + Write8: + movq %rdx, %rax + addq $32, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + addq %r10, %rdx + addq $32, %rdx + jmp WriteEnd + Write9: + movq %rdx, %rax + addq $36, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovss %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovss %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovss %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovss %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovss %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovss %xmm15, 32(%rdx) + addq %r10, %rdx + addq $36, %rdx + jmp WriteEnd + Write10: + movq %rdx, %rax + addq $40, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + addq %r10, %rdx + addq $40, %rdx + jmp WriteEnd + Write11: + movq %rdx, %rax + addq $44, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 40(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 40(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 40(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 40(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 40(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 40(%rdx) + addq %r10, %rdx + addq $44, %rdx + jmp WriteEnd + Write12: + movq %rdx, %rax + addq $48, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + addq %r10, %rdx + addq $48, %rdx + jmp WriteEnd + Write13: + movq %rdx, %rax + addq $52, %rax + movq 
%rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovss %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovss %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovss %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovss %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovss %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovss %xmm15, 48(%rdx) + addq %r10, %rdx + addq $52, %rdx + jmp WriteEnd + Write14: + movq %rdx, %rax + addq $56, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + addq %r10, %rdx + addq $56, %rdx + jmp WriteEnd + Write15: + movq %rdx, %rax + addq $60, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 56(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 56(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 56(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 56(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 56(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 56(%rdx) + addq %r10, %rdx + addq $60, %rdx + jmp WriteEnd + WriteC8: + movq %rdx, %rax + addq %r11, %rdx + movq %rdx, %r15 + addq %r11, %rdx + movq %rdx, -80(%rsp) + vmovups %ymm4, (%rax) + vmovups %ymm6, 32(%rax) + vmovups %ymm8, 64(%rax) + vmovups %ymm10, 96(%rax) + vmovups %ymm12, 128(%rax) + vmovups %ymm14, 
160(%rax) + vmovups %ymm5, (%r15) + vmovups %ymm7, 32(%r15) + vmovups %ymm9, 64(%r15) + vmovups %ymm11, 96(%r15) + vmovups %ymm13, 128(%r15) + vmovups %ymm15, 160(%r15) + jmp WriteEnd + WriteWino: + movq %rdx, %rax + addq %r13, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + addq %r12, %rdx + vmovups %ymm6, (%rdx) + addq %r12, %rdx + vmovups %ymm8, (%rdx) + addq %r12, %rdx + vmovups %ymm10, (%rdx) + addq %r12, %rdx + vmovups %ymm12, (%rdx) + addq %r12, %rdx + vmovups %ymm14, (%rdx) + cmpq $8, %rbx + je WriteEnd + movq %rax, %rdx + addq %r13, %rax + movq %rax, -80(%rsp) + vmovups %ymm5, (%rdx) + addq %r12, %rdx + vmovups %ymm7, (%rdx) + addq %r12, %rdx + vmovups %ymm9, (%rdx) + addq %r12, %rdx + vmovups %ymm11, (%rdx) + addq %r12, %rdx + vmovups %ymm13, (%rdx) + addq %r12, %rdx + vmovups %ymm15, (%rdx) + jmp WriteEnd + Write16: + movq %rdx, %rax + addq $64, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %ymm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %ymm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %ymm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %ymm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %ymm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %ymm15, 32(%rdx) + addq %r10, %rdx + addq $64, %rdx + + WriteEnd: + cmpq $16, %rbx + jbe LoopColEnd + subq $16, %rbx + jmp LoopCol + + LoopColEnd: + movq -96(%rsp), %rdi + addq %r11, %rdi + movq %rdi, -96(%rsp) + cmpq $0, %r14 + je C8DstStep + cmpq $2, %r14 + je WinoDstStep + movq $4, %rax + movq 16(%rsp), %rbx + imul %rbx, %rax + subq %rax, %rdx + movq %rdx, -80(%rsp) + jmp NoDstStep + C8DstStep: + movq -80(%rsp), %rax + addq $384, %rax + movq %rax, -80(%rsp) + jmp NoDstStep + WinoDstStep: + addq %r13, %rdx + movq %rdx, -80(%rsp) + NoDstStep: + cmpq $6, %rbp + jbe LoopRowEnd + subq $6, %rbp + jmp LoopRow + +LoopRowEnd: + subq $112, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rdx + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S new file mode 100644 index 0000000000000000000000000000000000000000..7afdeb0f5fbe0ca587d4bb9b1c70d6e4cc21aab4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S @@ -0,0 +1,499 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_AVX512
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 4
+
+// void ConvDwAVX512Fp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels,
+//                          size_t output_channel, size_t input_step, size_t first_calc_flag, const float *bias);
+// in linux x64 platform:
+// rdi: output_ptr
+// rsi: input_ptr
+// rdx: weight_ptr
+// rcx: num_pixels
+// r8: output_channel
+// r9: input_step
+// 8: first_calc_flag
+// 16: bias
+
+// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes
+// rcx: output_ptr
+// rdx: input_ptr
+// r8: weight_ptr
+// r9: num_pixels
+// 40: output_channel
+// 48: input_step
+// 56: first_calc_flag
+// 64: bias
+
+asm_function ConvDwAVX512Fp32Row
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rsi
+    pushq %rdi
+    addq $48, %rsp
+
+#ifdef _WIN32
+    movq %rcx, %rdi // output_ptr
+    movq %rdx, %rsi // input_ptr
+    movq %r8, %rdx  // weight_ptr
+    movq %r9, %rcx  // num_pixels
+    movq 40(%rsp), %r8 // output_channel
+    movq 48(%rsp), %r9 // input_step
+    movq 56(%rsp), %r11 // first_calc_flag
+    movq 64(%rsp), %r10 // bias
+#else
+    movq 8(%rsp), %r11  // first_calc_flag
+    movq 16(%rsp), %r10 // bias
+#endif
+
+    movq $4, %r13
+    imul %r13, %r9
+    movq %r8, %r12
+    imul %r13, %r12
+    movq %rsi, %r13 // input_ptr
+    movq %rdx, %r14 // weight_ptr
+    movq %r8, %r15  // output_channel
+
+    cmpq $1, %r11
+    je OutputInitByBias
+    jmp OutputInitBySelf
+
+OutputInitByBias:
+    cmpq $3, %rcx
+    jae BiasLoopPixelNum4
+    cmpq $0, %rcx
+    ja BiasLoopPixel
+    je End
+
+    BiasLoopPixelNum4:
+        movq %r13, %rsi // input_tmp
+        movq %r14, %rdx // weight_tmp
+        movq %r15, %r8  // channel_tmp
+        movq 16(%rsp), %r10 // bias_tmp
+
+        cmpq $16, %r8
+        jae BiasLoopC16Num4
+        cmpq $8, %r8
+        jae BiasLoopC8Num4
+        cmpq $0, %r8
+        ja BiasLoopCNum4
+        jmp BiasLoopCEndNum4
+
+        BiasLoopC16Num4:
+            vmovups (%rsi), %zmm0 // input_tmp
+            vmovups (%rsi, %r9), %zmm1
+            vmovups (%rsi, %r9, 2), %zmm2
+            // vmovups (%rsi, %r9, 3), %zmm3
+
+            vmovups (%r10), %zmm8  // output_tmp
+            vmovups (%r10), %zmm9  // output_tmp
+            vmovups (%r10), %zmm10 // output_tmp
+            // vmovups (%r10), %zmm11 // output_tmp
+            addq $64, %rsi
+            addq $64, %r10
+
+            vmovups (%rdx), %zmm15 // weight_tmp
+            vfmadd231ps %zmm15, %zmm0, %zmm8
+            vfmadd231ps %zmm15, %zmm1, %zmm9
+            vfmadd231ps %zmm15, %zmm2, %zmm10
+            // vfmadd231ps %zmm15, %zmm3, %zmm11
+
+            addq $64, %rdx
+            vmovups %zmm8, (%rdi)
+            vmovups %zmm9, (%rdi, %r12)
+            vmovups %zmm10, (%rdi, %r12, 2)
+            // vmovups %zmm11, (%rdi, %r12, 3)
+            addq $64, %rdi
+
+            subq $16, %r8
+            cmpq $16, %r8
+            jae BiasLoopC16Num4
+            cmpq $8, %r8
+            jae BiasLoopC8Num4
+            cmpq $0, %r8
+            ja BiasLoopCNum4
+            jmp BiasLoopCEndNum4
+
+        BiasLoopC8Num4:
+            vmovups (%rsi), %ymm0 // input_tmp
+            vmovups (%rsi, %r9), %ymm1
+            vmovups (%rsi, %r9, 2), %ymm2
+            // vmovups (%rsi, %r9, 3), %ymm3
+
+            vmovups (%r10), %ymm8  // output_tmp
+            vmovups (%r10), %ymm9  // output_tmp
+            vmovups (%r10), %ymm10 // output_tmp
+            // vmovups (%r10), %ymm11 // output_tmp
+            addq $32, %rsi
+            addq $32, %r10
+
+            vmovups (%rdx), %ymm15 // weight_tmp
+            vfmadd231ps %ymm15, %ymm0, %ymm8
+            vfmadd231ps %ymm15, %ymm1, %ymm9
+            vfmadd231ps %ymm15, %ymm2, %ymm10
+            // vfmadd231ps %ymm15, %ymm3, %ymm11
+
+            addq $32, %rdx
+            vmovups %ymm8, (%rdi)
+            vmovups %ymm9, (%rdi, %r12)
+            vmovups %ymm10, (%rdi, %r12, 2)
+            // vmovups %ymm11, (%rdi, %r12, 3)
+            addq $32, %rdi
+
+            subq $8, %r8
+            cmpq $8, %r8
+            jae BiasLoopC8Num4
+            cmpq $0, %r8
+            ja BiasLoopCNum4
+            jmp BiasLoopCEndNum4
+
+        BiasLoopCNum4:
+            vmovss (%rsi), %xmm0 // input_tmp
+            vmovss (%rsi, %r9), %xmm1
+            vmovss (%rsi, %r9, 2), %xmm2
+            // vmovss (%rsi, %r9, 3), %xmm3
+
+            vmovss
(%r10), %xmm8 // output_ptr + vmovss (%r10), %xmm9 // output_tmp + vmovss (%r10), %xmm10 // output_tmp + // vmovss (%r10), %xmm11 // output_tmp + addq $4, %r10 + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $16, %r8 + jae BiasLoopC16 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC16: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%r10), %zmm8 // output_tmp + addq $64, %rsi + addq $64, %r10 + + vfmadd231ps (%rdx), %zmm0, %zmm8 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae BiasLoopC16 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%r10), %ymm8 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%r10), %xmm8 // output_ptr + addq $4, %r10 + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp BiasLoopPixel + +OutputInitBySelf: + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $16, %r8 + jae LoopC16Num4 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC16Num4: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rsi, %r9), %zmm1 + vmovups (%rsi, %r9, 2), %zmm2 + // vmovups (%rsi, %r9, 3), %zmm3 + + vmovups (%rdi), %zmm8 // output_tmp + vmovups (%rdi, %r12), %zmm9 // output_tmp + vmovups (%rdi, %r12, 2), %zmm10 // output_tmp + // vmovups (%rdi, %r12, 3), %zmm11 // output_tmp + addq $64, %rsi + + vmovups (%rdx), %zmm15 // weight_tmp + vfmadd231ps %zmm15, %zmm0, %zmm8 + vfmadd231ps %zmm15, %zmm1, %zmm9 + vfmadd231ps %zmm15, %zmm2, %zmm10 + // vfmadd231ps %zmm15, %zmm3, %zmm11 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + vmovups %zmm9, (%rdi, %r12) + vmovups %zmm10, (%rdi, %r12, 2) + // vmovups %zmm11, (%rdi, %r12, 3) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16Num4 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups (%rdi, %r12), %ymm9 // output_tmp + vmovups 
(%rdi, %r12, 2), %ymm10 // output_tmp + // vmovups (%rdi, %r12, 3), %ymm11 // output_tmp + addq $32, %rsi + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%rdi), %xmm8 // output_ptr + vmovss (%rdi, %r12), %xmm9 // output_tmp + vmovss (%rdi, %r12, 2), %xmm10 // output_tmp + // vmovss (%rdi, %r12, 3), %xmm11 // output_tmp + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC16: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rdi), %zmm8 // output_tmp + addq $64, %rsi + + vfmadd231ps (%rdx), %zmm0, %zmm8 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S new file mode 100644 index 0000000000000000000000000000000000000000..85b202ae6b3c399bfeb34cbf8a150e3e03863e97 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void CalculateMinMaxCount8Fp16(const float16_t *data, int count_8, float16_t *real_min, float16_t *real_max); +// x0: data +// w1: count_8 +// x2: real_min +// x3: real_max + +asm_function CalculateMinMaxCount8Fp16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + + mov x4, x0 // reload data + mov w5, w1 // reload count + ld1 {v31.8h}, [x4] + ld1 {v30.8h}, [x4], #16 + subs w5, w5, #8 + ble Write + + Loop: + ld1 {v0.8h}, [x4], #16 + fmin v31.8h, v31.8h, v0.8h + fmax v30.8h, v30.8h, v0.8h + subs w5, w5, #8 + bgt Loop + + Write: + fminv h6, v31.8h + fmaxv h7, v30.8h + + str h6, [x2] + str h7, [x3] + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S new file mode 100644 index 0000000000000000000000000000000000000000..4822c93b83f87abd9041a2df9c4bf6c714629506 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, +// size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, +// size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, +// x8: kernel_w, x9: relu, x10: relu6 +asm_function ConvDwFp16Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + + ld1 {v0.8h}, [x3] // bias + movi v1.8h, #0x46, lsl #8 // relu 6 + dup v2.4s, wzr // relu + + mov x13, x1 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x5 + LoopW: + ld1 {v3.8h}, [x15], x7 + ld1 {v4.8h}, [x16], #16 + fmla v0.8h, v3.8h, v4.8h + subs x17, x17, #1 + bne LoopW + subs x4, x4, #1 + add x13, x13, x6 + add x14, x14, x8 + bne LoopH + cbnz x10, Relu6 + cbnz x9, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v1.8h + Relu: + fmax v0.8h, v0.8h, v2.8h + Write: + st1 {v0.8h}, [x0] + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S new file mode 100644 index 0000000000000000000000000000000000000000..1b2534379a6ae471733c67d213178380706beb93 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S @@ -0,0 +1,312 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +asm_function ConvDwFp16Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + + ld1 {v24.8h}, [x3] + movi v26.8h, #0x46, lsl #8 + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x25, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v8.8h, v16.8h, v25.8h + fmla v9.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v10.8h, v18.8h, v25.8h + fmla v11.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v12.8h, v20.8h, v25.8h + fmla v13.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v14.8h, v22.8h, v25.8h + fmla v15.8h, v23.8h, v25.8h + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + fmin v8.8h, v8.8h, v26.8h + fmin v9.8h, v9.8h, v26.8h + fmin v10.8h, v10.8h, v26.8h + fmin v11.8h, v11.8h, v26.8h + fmin v12.8h, v12.8h, 
v26.8h + fmin v13.8h, v13.8h, v26.8h + fmin v14.8h, v14.8h, v26.8h + fmin v15.8h, v15.8h, v26.8h + Relu16: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + fmax v8.8h, v8.8h, v27.8h + fmax v9.8h, v9.8h, v27.8h + fmax v10.8h, v10.8h, v27.8h + fmax v11.8h, v11.8h, v27.8h + fmax v12.8h, v12.8h, v27.8h + fmax v13.8h, v13.8h, v27.8h + fmax v14.8h, v14.8h, v27.8h + fmax v15.8h, v15.8h, v27.8h + Write16: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + st1 {v8.8h}, [x3], x9 + st1 {v9.8h}, [x3], x9 + st1 {v10.8h}, [x3], x9 + st1 {v11.8h}, [x3], x9 + st1 {v12.8h}, [x3], x9 + st1 {v13.8h}, [x3], x9 + st1 {v14.8h}, [x3], x9 + st1 {v15.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x25, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + Relu8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + Write8: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v16.8h}, [x22], x13 + ld1 {v25.8h}, [x17], #16 + fmla v0.8h, v16.8h, v25.8h + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz x14, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v26.8h + Relu: + fmax v0.8h, v0.8h, v27.8h + Write: + st1 {v0.8h}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, 
x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S new file mode 100644 index 0000000000000000000000000000000000000000..2238257d7c748aec6538cf4bfc6a1cdfca36a8af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Row(float16_t* output_ptr, const float16_t* input_ptr,const float16_t* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels, +// x4: input_channel, x5: input_step +// +asm_function ConvDwFp16Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +beq End + +mov x9, x0 +mov x12, #2 // sizeof(float16_t) +mul x5, x5, x12 + +LoopOutPixel: +mov x6, x1 +mov x7, x2 +mov x8, x4 + +LoopInputDepth32In: + cmp x8, #32 + blt Loop8 + sub x8, x8, #32 + + ld1 {v0.8h, v1.8h}, [x6], #32 + ld1 {v2.8h, v3.8h}, [x7], #32 + ld1 {v16.8h, v17.8h}, [x0], #32 + + cmp x8, #32 + blt LoopInputDepth32Out + LoopInputDepth32: + fmla v16.8h, v0.8h, v2.8h + fmla v17.8h, v1.8h, v3.8h + + st1 {v16.8h, v17.8h}, [x9], #32 + + ld1 {v4.8h, v5.8h}, [x6], #32 + ld1 {v6.8h, v7.8h}, [x7], #32 + ld1 {v18.8h, v19.8h}, [x0], #32 + + fmla v18.8h, v4.8h, v6.8h + fmla v19.8h, v5.8h, v7.8h + + st1 {v18.8h, v19.8h}, [x9], #32 + + ld1 {v0.8h, v1.8h}, [x6], #32 + ld1 {v2.8h, v3.8h}, [x7], #32 + ld1 {v16.8h, v17.8h}, [x0], #32 + + sub x8, x8, #32 + cmp x8, #32 + bge LoopInputDepth32 + + LoopInputDepth32Out: + fmla v16.8h, v0.8h, v2.8h + fmla v17.8h, v1.8h, v3.8h + st1 {v16.8h, v17.8h}, [x9], #32 + + ld1 {v4.8h, v5.8h}, [x6], #32 + ld1 {v6.8h, v7.8h}, [x7], #32 + ld1 {v18.8h, v19.8h}, [x0], #32 + + fmla v18.8h, v4.8h, v6.8h + fmla v19.8h, v5.8h, v7.8h + + st1 {v18.8h, v19.8h}, [x9], #32 + + Loop8: + cmp x8, #8 + blt L0 + + LoopInputDepth8: + ld1 {v0.8h}, [x6], #16 + ld1 {v2.8h}, [x7], #16 + ld1 {v16.8h}, [x0], #16 + fmla v16.8h, v0.8h, v2.8h + st1 {v16.8h}, [x9], #16 + sub x8, x8, #8 + cmp x8, #8 + bge LoopInputDepth8 + + L0: + cmp x8, #0 + beq Loop8LineEnd + + LoopInputDepth0: + ldr h0, [x6], #2 + ldr h1, [x7], #2 + ldr h2, [x0], #2 + fmul h0, h0, h1 + fadd h2, h2, h0 + str h2, [x9], #2 + subs x8, x8, #1 + bne LoopInputDepth0 + + Loop8LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S new file mode 100644 index 0000000000000000000000000000000000000000..103985c75b5acec688eee9956b97b89c9db92fe4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w) + +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: in_kh_step, x6: in_kw_step, x7: kernel_w +asm_function DeconvDwFp16Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ld1 {v1.8h}, [x1] + + mov x13, x0 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x4 + LoopW: + ld1 {v0.8h}, [x15] + ld1 {v2.8h}, [x16], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x15], x6 + subs x17, x17, #1 + bne LoopW + subs x3, x3, #1 + add x13, x13, x5 + add x14, x14, x7 + bne LoopH + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S new file mode 100644 index 0000000000000000000000000000000000000000..44f0c1ce6ad34d9c26230cfb9f272308826a6243 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwFp16Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x22, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.8h}, [x16], x8 + LoopKh: + mov x21, x22 + mov x13, x6 + LoopKw: + ld1 {v0.8h}, [x21] + ld1 {v2.8h}, [x19], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x22, x22, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..c27ade021349fffa53b5c8c7262e7aa7e26be6b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S @@ -0,0 +1,54 @@ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +// void DynamicGatherArm64ForFp16(const int8_t *src, float16_t *output, int count_16, int zp, float scale); +// x0: src(left matrix ptr) +// x1: output(right matrix ptr) +// w2: count_16 +// w3: zp +// s0: scale + +asm_function DynamicGatherArm64ForFp16 + mov x5, x0 // reload src + mov x6, x1 // reload out + mov w7, w2 // reload count_16 + dup v1.4s, w3 // zp + dup v2.4s, v0.s[0] // scale + + LoopCount: + ld1 {v0.16b}, [x5], #16 + + sxtl v3.8h, v0.8b + sxtl2 v4.8h, v0.16b + + sxtl v16.4s, v3.4h + sxtl2 v17.4s, v3.8h + sxtl v18.4s, v4.4h + sxtl2 v19.4s, v4.8h + + sub v16.4s, v16.4s, v1.4s + scvtf v16.4s,v16.4s + fmul v16.4s, v16.4s, v2.4s + sub v17.4s, v17.4s, v1.4s + scvtf v17.4s,v17.4s + fmul v17.4s, v17.4s, v2.4s + sub v18.4s, v18.4s, v1.4s + scvtf v18.4s,v18.4s + fmul v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v1.4s + scvtf v19.4s,v19.4s + fmul v19.4s, v19.4s, v2.4s + + fcvtn v16.4h, v16.4s + fcvtn v17.4h, v17.4s + fcvtn v18.4h, v18.4s + fcvtn v19.4h, v19.4s + + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x6], #32 + subs w7, w7, #16 + bgt LoopCount +ret + +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S new file mode 100644 index 0000000000000000000000000000000000000000..391736583b2075daefd08443db51cc53a258e06f --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void Float16ToFloat32(const float16_t *input, float *output, int number); +// x0: input, x1: output, x2: number +asm_function Float16ToFloat32 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x2, #0 + beq LoopEnd + cmp x2, #64 + blt Loop + Loop64: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 + fcvtl v16.4s, v0.4h + fcvtl2 v17.4s, v0.8h + fcvtl v18.4s, v1.4h + fcvtl2 v19.4s, v1.8h + fcvtl v20.4s, v2.4h + fcvtl2 v21.4s, v2.8h + fcvtl v22.4s, v3.4h + fcvtl2 v23.4s, v3.8h + fcvtl v24.4s, v4.4h + fcvtl2 v25.4s, v4.8h + fcvtl v26.4s, v5.4h + fcvtl2 v27.4s, v5.8h + fcvtl v28.4s, v6.4h + fcvtl2 v29.4s, v6.8h + fcvtl v30.4s, v7.4h + fcvtl2 v31.4s, v7.8h + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x1], #64 + subs x2, x2, #64 + ble LoopEnd + cmp x2, #64 + bge Loop64 + Loop: + ldr h0, [x0], #2 + fcvt s0, h0 + str s0, [x1], #4 + subs x2, x2, #1 + bgt Loop + LoopEnd: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S new file mode 100644 index 0000000000000000000000000000000000000000..b40a8aae16ebc45afb341651453410cd8472cff5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void Float32ToFloat16(const float *input, float16_t *output, int number); +// x0: input, x1: output, x2: number +asm_function Float32ToFloat16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x2, #0 + beq LoopEnd + cmp x2, #64 + blt Loop + Loop64: + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64 + fcvtn v0.4h, v16.4s + fcvtn2 v0.8h, v17.4s + fcvtn v1.4h, v18.4s + fcvtn2 v1.8h, v19.4s + fcvtn v2.4h, v20.4s + fcvtn2 v2.8h, v21.4s + fcvtn v3.4h, v22.4s + fcvtn2 v3.8h, v23.4s + fcvtn v4.4h, v24.4s + fcvtn2 v4.8h, v25.4s + fcvtn v5.4h, v26.4s + fcvtn2 v5.8h, v27.4s + fcvtn v6.4h, v28.4s + fcvtn2 v6.8h, v29.4s + fcvtn v7.4h, v30.4s + fcvtn2 v7.8h, v31.4s + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64 + subs x2, x2, #64 + ble LoopEnd + cmp x2, #64 + bge Loop64 + Loop: + ldr s0, [x0], #4 + fcvt h0, s0 + str h0, [x1], #2 + subs x2, x2, #1 + bgt Loop + LoopEnd: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..c5aa798ab5be30e1187e2e05951ebbf38bb5a8ff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S @@ -0,0 +1,191 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_function MatVecMulFp16Neon64 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + + mov w14, #2 // sizeof(float16) + mul w8, w14, w5 // rhs depthx1 block stride + mov w14, #4 + mul w13, w8, w14 // rhs depthx4 block stride + +Loop: + mov x15, x0 // reload a ptr + mov x7, x1 // reload b ptr + mov w9, w5 // reload depth + cmp w6, #4 + blt Loop1x1 + +Loop1x4: + dup v5.8h, wzr + dup v6.8h, wzr + dup v7.8h, wzr + dup v8.8h, wzr + dup v9.8h, wzr + dup v10.8h, wzr + dup v11.8h, wzr + dup v12.8h, wzr + dup v13.8h, wzr + + add x10, x7, x8 + add x11, x10, x8 + add x12, x11, x8 + +Depth8_1x4: + cmp w9, #8 + blt Depth1_1x4 + + ld1 {v0.8h}, [x15], #16 + ld1 {v1.8h}, [x7], #16 + ld1 {v2.8h}, [x10], #16 + ld1 {v3.8h}, [x11], #16 + ld1 {v4.8h}, [x12], #16 + + fmla v5.8h, v1.8h, v0.8h + fmla v6.8h, v2.8h, v0.8h + fmla v7.8h, v3.8h, v0.8h + fmla v8.8h, v4.8h, v0.8h + sub w9, w9, #8 + cbz w9, End1x4 + b Depth8_1x4 + +Depth1_1x4: + ld1 {v0.h}[0], [x15], #2 + ld1 {v1.h}[0], [x7], #2 + ld1 {v1.h}[1], [x10], #2 + ld1 {v1.h}[2], [x11], #2 + ld1 {v1.h}[3], [x12], #2 + + fmla v9.8h, v1.8h, v0.h[0] + sub w9, w9, #1 + cbz w9, End1x4 + b Depth1_1x4 + +End1x4: + faddp v10.8h, v5.8h, v6.8h + faddp v11.8h, v7.8h, v8.8h + faddp v12.8h, v10.8h, v11.8h + faddp v13.8h, v12.8h, v12.8h + fadd v13.8h, v13.8h, v9.8h + + cbz x3, Act1x4 + ld1 {v14.4h}, [x3], #8 + fadd v13.8h, v13.8h, v14.8h + +Act1x4: + cmp w4, #3 + beq Relu6_1x4 + cmp w4, #1 + beq Relu1x4 + b Write1x4 + +Relu6_1x4: + movi v14.8h, #0x46, lsl #8 + fmin v13.8h, v13.8h, v14.8h + +Relu1x4: + dup v14.8h, wzr + fmax v13.8h, v13.8h, v14.8h + +Write1x4: + st1 {v13.4h}, [x2], #8 + sub w6, w6, #4 + cbz w6, End + add x1, x1, x13 + b Loop + +Loop1x1: + dup v2.8h, wzr + dup v3.8h, wzr + dup v4.8h, wzr + dup v5.8h, wzr + dup v6.8h, wzr + +Depth8_1x1: + cmp w9, #8 + blt Depth1_1x1 + + ld1 {v0.8h}, [x15], #16 + ld1 {v1.8h}, [x7], #16 + + fmla v2.8h, v1.8h, v0.8h + sub w9, w9, #8 + cbz w9, End1x1 + b Depth8_1x1 + +Depth1_1x1: + ld1 {v0.h}[0], [x15], #2 + ld1 {v1.h}[0], [x7], #2 + + fmla v3.8h, v1.8h, v0.h[0] + sub w9, w9, #1 + cbz w9, End1x1 + b Depth1_1x1 + +End1x1: + faddp v4.8h, v2.8h, v2.8h + faddp v5.8h, v4.8h, v4.8h + faddp v6.8h, v5.8h, v5.8h + fadd v6.8h, v6.8h, v3.8h + + cbz x3, Act1x1 + ld1 {v7.h}[0], [x3], #2 + fadd v6.8h, v6.8h, v7.8h + +Act1x1: + cmp w4, #3 + beq Relu6_1x1 + cmp w4, #1 + beq Relu1x1 + b Write1x1 + +Relu6_1x1: + movi v7.8h, #0x46, lsl #8 + fmin v6.8h, v6.8h, v7.8h + +Relu1x1: + dup v7.8h, wzr + fmax v6.8h, v6.8h, v7.8h + +Write1x1: + st1 {v6.h}[0], [x2], #2 + sub w6, w6, #1 + cbz w6, End + add x1, x1, x8 + b Loop + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..0af3589e482e9084e326746ce421aa0d5ec66ac1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S @@ -0,0 +1,1703 @@ +/** + * Copyright 2021 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type : ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu +// x5: depth : Ic +// x6: row : remain_row +// x7: col +// x8: stride : output_stride x8 = x8 * 2 +// x9: writeMode : OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 + +// x17 : input_stride + +asm_function MatMul12x16Fp16Opt + sub sp, sp, #160 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + +.macro CLEAR_OUTPUT_V8_V9 + dup v8.4s, wzr + dup v9.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V11 + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V15 + CLEAR_OUTPUT_V8_V11 + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V23 + CLEAR_OUTPUT_V8_V15 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V31 + CLEAR_OUTPUT_V8_V23 + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +.endm + + mov x21, #24 + mul x17, x5, x21 // input_stride : 12 * Ic * sizeof(float16_t) + mov x21, #2 + mul x8, x8, x21 // output_stride + +LoopRowStart: + cmp x6, #1 + ble LoopRow1 + cmp x6, #2 + ble LoopRow2 + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow12: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol12: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V31 + cmp x19, #2 + blt LoopDepth12One + + LoopDepth12: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + fmla v16.8h, v4.8h, v0.h[4] + fmla v17.8h, v5.8h, v0.h[4] + fmla v18.8h, v4.8h, v0.h[5] + fmla v19.8h, v5.8h, v0.h[5] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + fmla v20.8h, v4.8h, v0.h[6] + fmla v21.8h, v5.8h, v0.h[6] + fmla v22.8h, v4.8h, v0.h[7] + fmla v23.8h, v5.8h, v0.h[7] + fmla v24.8h, v4.8h, v1.h[0] + fmla v25.8h, v5.8h, v1.h[0] + fmla v26.8h, v4.8h, v1.h[1] + fmla v27.8h, v5.8h, v1.h[1] + fmla v28.8h, v4.8h, v1.h[2] + 
fmla v29.8h, v5.8h, v1.h[2] + fmla v30.8h, v4.8h, v1.h[3] + fmla v31.8h, v5.8h, v1.h[3] + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + fmla v16.8h, v6.8h, v2.h[4] + fmla v17.8h, v7.8h, v2.h[4] + fmla v18.8h, v6.8h, v2.h[5] + fmla v19.8h, v7.8h, v2.h[5] + fmla v20.8h, v6.8h, v2.h[6] + fmla v21.8h, v7.8h, v2.h[6] + fmla v22.8h, v6.8h, v2.h[7] + fmla v23.8h, v7.8h, v2.h[7] + fmla v24.8h, v6.8h, v3.h[0] + fmla v25.8h, v7.8h, v3.h[0] + fmla v26.8h, v6.8h, v3.h[1] + fmla v27.8h, v7.8h, v3.h[1] + fmla v28.8h, v6.8h, v3.h[2] + fmla v29.8h, v7.8h, v3.h[2] + fmla v30.8h, v6.8h, v3.h[3] + fmla v31.8h, v7.8h, v3.h[3] + subs x19, x19, #2 + beq Bias12 + cmp x19, #2 + bge LoopDepth12 + + LoopDepth12One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + fmla v16.8h, v3.8h, v1.h[0] + fmla v17.8h, v4.8h, v1.h[0] + fmla v18.8h, v3.8h, v1.h[1] + fmla v19.8h, v4.8h, v1.h[1] + fmla v20.8h, v3.8h, v1.h[2] + fmla v21.8h, v4.8h, v1.h[2] + fmla v22.8h, v3.8h, v1.h[3] + fmla v23.8h, v4.8h, v1.h[3] + fmla v24.8h, v3.8h, v2.h[0] + fmla v25.8h, v4.8h, v2.h[0] + fmla v26.8h, v3.8h, v2.h[1] + fmla v27.8h, v4.8h, v2.h[1] + fmla v28.8h, v3.8h, v2.h[2] + fmla v29.8h, v4.8h, v2.h[2] + fmla v30.8h, v3.8h, v2.h[3] + fmla v31.8h, v4.8h, v2.h[3] + subs x19, x19, #1 + bgt LoopDepth12One + + Bias12: + cbz x3, Activation12 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v1.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v1.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v1.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v1.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v1.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v1.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v1.8h + + Activation12: + cmp x4, #3 + beq Relu612 + cmp x4, #1 + beq Relu12 + b Write + + Relu612: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + fmin v24.8h, v24.8h, v2.8h + fmin v25.8h, v25.8h, v2.8h + fmin v26.8h, v26.8h, v2.8h + fmin v27.8h, v27.8h, v2.8h + fmin v28.8h, v28.8h, v2.8h + fmin v29.8h, v29.8h, v2.8h + fmin v30.8h, v30.8h, v2.8h + fmin v31.8h, v31.8h, v2.8h + + Relu12: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h 
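+ // v2 holds 0.0 (fp16): these fmax ops apply the ReLU floor to the remaining accumulators v15 ~ v31; the Relu612 path above falls through here, so Relu6 gets both the 6.0 (0x4600) ceiling and this zero floor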
+ fmax v15.8h, v15.8h, v2.8h + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + fmax v24.8h, v24.8h, v2.8h + fmax v25.8h, v25.8h, v2.8h + fmax v26.8h, v26.8h, v2.8h + fmax v27.8h, v27.8h, v2.8h + fmax v28.8h, v28.8h, v2.8h + fmax v29.8h, v29.8h, v2.8h + fmax v30.8h, v30.8h, v2.8h + fmax v31.8h, v31.8h, v2.8h + b Write + +LoopRow8: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol8: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V23 + cmp x19, #2 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + fmla v16.8h, v4.8h, v0.h[4] + fmla v17.8h, v5.8h, v0.h[4] + fmla v18.8h, v4.8h, v0.h[5] + fmla v19.8h, v5.8h, v0.h[5] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + fmla v20.8h, v4.8h, v0.h[6] + fmla v21.8h, v5.8h, v0.h[6] + fmla v22.8h, v4.8h, v0.h[7] + fmla v23.8h, v5.8h, v0.h[7] + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + fmla v16.8h, v6.8h, v2.h[4] + fmla v17.8h, v7.8h, v2.h[4] + fmla v18.8h, v6.8h, v2.h[5] + fmla v19.8h, v7.8h, v2.h[5] + fmla v20.8h, v6.8h, v2.h[6] + fmla v21.8h, v7.8h, v2.h[6] + fmla v22.8h, v6.8h, v2.h[7] + fmla v23.8h, v7.8h, v2.h[7] + subs x19, x19, #2 + beq Bias8 + cmp x19, #2 + bge LoopDepth8 + + LoopDepth8One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + fmla v16.8h, v3.8h, v1.h[0] + fmla v17.8h, v4.8h, v1.h[0] + fmla v18.8h, v3.8h, v1.h[1] + fmla v19.8h, v4.8h, v1.h[1] + fmla v20.8h, v3.8h, v1.h[2] + fmla v21.8h, v4.8h, v1.h[2] + fmla v22.8h, v3.8h, v1.h[3] + fmla v23.8h, v4.8h, v1.h[3] + subs x19, x19, #1 + bgt LoopDepth8One + + Bias8: + cbz x3, Activation8 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v1.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v1.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v1.8h + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + fmin v16.8h, v16.8h, v2.8h + 
fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + + Relu8: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + b Write + +LoopRow4: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol4: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V15 + cmp x19, #2 + blt LoopDepth4One + + LoopDepth4: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + subs x19, x19, #2 + beq Bias4 + cmp x19, #2 + bge LoopDepth4 + + LoopDepth4One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + subs x19, x19, #1 + bgt LoopDepth4One + + Bias4: + cbz x3, Activation4 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + + Relu4: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + b Write + +LoopRow2: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol2: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V11 + cmp x19, #2 + blt LoopDepth2One + + LoopDepth2: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + ld1 {v2.8h}, [x10], 
#16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + subs x19, x19, #2 + beq Bias2 + cmp x19, #2 + bge LoopDepth2 + + LoopDepth2One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + subs x19, x19, #1 + bgt LoopDepth2One + + Bias2: + cbz x3, Activation2 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + + Activation2: + cmp x4, #3 + beq Relu62 + cmp x4, #1 + beq Relu2 + b Write + + Relu62: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + + Relu2: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + b Write + +LoopRow1: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol1: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V9 + cmp x19, #2 + blt LoopDepth1One + + LoopDepth1: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + subs x19, x19, #2 + beq Bias1 + cmp x19, #2 + bge LoopDepth1 + + LoopDepth1One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + subs x19, x19, #1 + bgt LoopDepth1One + + Bias1: + cbz x3, Activation1 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + + Activation1: + cmp x4, #3 + beq Relu61 + cmp x4, #1 + beq Relu1 + b Write + + Relu61: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + + Relu1: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + b Write + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + cmp x13, #8 + beq Write8 + cmp x13, #9 + beq Write9 + cmp x13, #10 + beq Write10 + cmp x13, #11 + beq Write11 + cmp x13, #12 + beq Write12 + cmp x13, #13 + beq Write13 + cmp x13, #14 + beq Write14 + cmp x13, #15 + beq Write15 + b Write16 + + Write1: + add x2, x2, #2 + str h8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str h10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str h12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str h14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str h16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str h18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str h20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str h22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str h24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str h26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str h28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + 
str h30, [x11] + add x11, x11, x8 + add x11, x11, #2 + b WriteEnd + + Write2: + add x2, x2, #4 + add x19, x11, #2 + st1 {v8.h}[0], [x11], x8 + st1 {v8.h}[1], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.h}[0], [x11], x8 + st1 {v10.h}[1], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.h}[0], [x11], x8 + st1 {v12.h}[1], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.h}[0], [x11], x8 + st1 {v14.h}[1], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.h}[0], [x11], x8 + st1 {v16.h}[1], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + st1 {v18.h}[1], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + st1 {v20.h}[1], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + st1 {v22.h}[1], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + st1 {v24.h}[1], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + st1 {v26.h}[1], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + st1 {v28.h}[1], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + st1 {v30.h}[1], [x19], x8 + add x11, x11, #4 + b WriteEnd + + Write3: + add x2, x2, #6 + add x19, x11, #4 + add x20, x11, #2 + st1 {v8.h}[0], [x11], x8 + st1 {v8.h}[1], [x20], x8 + st1 {v8.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.h}[0], [x11], x8 + st1 {v10.h}[1], [x20], x8 + st1 {v10.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.h}[0], [x11], x8 + st1 {v12.h}[1], [x20], x8 + st1 {v12.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.h}[0], [x11], x8 + st1 {v14.h}[1], [x20], x8 + st1 {v14.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.h}[0], [x11], x8 + st1 {v16.h}[1], [x20], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + st1 {v18.h}[1], [x20], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + st1 {v20.h}[1], [x20], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + st1 {v22.h}[1], [x20], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + st1 {v24.h}[1], [x20], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + st1 {v26.h}[1], [x20], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + st1 {v28.h}[1], [x20], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + st1 {v30.h}[1], [x20], x8 + st1 {v30.h}[2], [x19], x8 + add x11, x11, #6 + b WriteEnd + + Write4: + add x2, x2, #8 + st1 {v8.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 
{v14.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + add x11, x11, #10 + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + add x20, x11, #10 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + st1 {v8.h}[5], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + st1 {v10.h}[5], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + st1 {v12.h}[5], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + st1 {v14.h}[5], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + st1 {v16.h}[5], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + st1 {v18.h}[5], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + st1 {v20.h}[5], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + st1 {v22.h}[5], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + st1 {v24.h}[5], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + st1 {v26.h}[5], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + st1 {v28.h}[5], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + st1 {v30.h}[5], [x20], x8 + add x11, x11, #12 + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x20, x11, #10 + add x10, x11, #12 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + st1 {v8.h}[5], [x20], x8 + st1 {v8.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + st1 {v10.h}[5], [x20], x8 + st1 {v10.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + st1 {v12.h}[5], [x20], x8 + st1 {v12.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + st1 {v14.h}[5], [x20], x8 + st1 {v14.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + st1 {v16.h}[5], [x20], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + st1 {v18.h}[5], [x20], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + st1 {v20.h}[5], [x20], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + st1 {v22.h}[5], [x20], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + st1 {v24.h}[5], [x20], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + st1 {v26.h}[5], [x20], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #10 + 
beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + st1 {v28.h}[5], [x20], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + st1 {v30.h}[5], [x20], x8 + st1 {v30.h}[6], [x10], x8 + add x11, x11, #14 + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v8.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write9: + add x2, x2, #18 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + add x11, x11, #18 + b WriteEnd + Write10: + add x2, x2, #20 + add x19, x11, #16 + add x20, x11, #18 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + st1 {v9.h}[1], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + st1 {v11.h}[1], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + st1 {v13.h}[1], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + st1 {v15.h}[1], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + st1 {v17.h}[1], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + st1 {v19.h}[1], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + st1 {v21.h}[1], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + st1 {v23.h}[1], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + st1 {v25.h}[1], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + st1 {v27.h}[1], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + st1 {v29.h}[1], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + st1 {v31.h}[1], [x20], x8 + add x11, x11, #20 + b WriteEnd + Write11: + add x2, x2, #22 + add x19, x11, #16 + add x20, x11, #18 + add x10, x11, #20 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + st1 
{v9.h}[1], [x20], x8 + st1 {v9.h}[2], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + st1 {v11.h}[1], [x20], x8 + st1 {v11.h}[2], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + st1 {v13.h}[1], [x20], x8 + st1 {v13.h}[2], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + st1 {v15.h}[1], [x20], x8 + st1 {v15.h}[2], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + st1 {v17.h}[1], [x20], x8 + st1 {v17.h}[2], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + st1 {v19.h}[1], [x20], x8 + st1 {v19.h}[2], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + st1 {v21.h}[1], [x20], x8 + st1 {v21.h}[2], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + st1 {v23.h}[1], [x20], x8 + st1 {v23.h}[2], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + st1 {v25.h}[1], [x20], x8 + st1 {v25.h}[2], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + st1 {v27.h}[1], [x20], x8 + st1 {v27.h}[2], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + st1 {v29.h}[1], [x20], x8 + st1 {v29.h}[2], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + st1 {v31.h}[1], [x20], x8 + st1 {v31.h}[2], [x10], x8 + add x11, x11, #22 + b WriteEnd + Write12: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write13: + add x2, x2, #26 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 
{v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + add x11, x11, #26 + b WriteEnd + Write14: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + add x10, x11, #26 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + st1 {v9.h}[5], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + st1 {v11.h}[5], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + st1 {v13.h}[5], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + st1 {v15.h}[5], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + st1 {v17.h}[5], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + st1 {v19.h}[5], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + st1 {v21.h}[5], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + st1 {v23.h}[5], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + st1 {v25.h}[5], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + st1 {v27.h}[5], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + st1 {v29.h}[5], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + st1 {v31.h}[5], [x10], x8 + add x11, x11, #28 + b WriteEnd + Write15: + add x2, x2, #30 + add x19, x11, #16 + add x20, x11, #24 + add x10, x11, #26 + add x15, x11, #28 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + st1 {v9.h}[5], [x10], x8 + st1 {v9.h}[6], [x15], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + st1 {v11.h}[5], [x10], x8 + st1 {v11.h}[6], [x15], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + st1 {v13.h}[5], [x10], x8 + st1 {v13.h}[6], [x15], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + st1 {v15.h}[5], [x10], x8 + st1 {v15.h}[6], [x15], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + st1 {v17.h}[5], [x10], x8 + st1 {v17.h}[6], [x15], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + st1 {v19.h}[5], [x10], x8 + st1 {v19.h}[6], [x15], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + st1 {v21.h}[5], [x10], x8 + st1 {v21.h}[6], [x15], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + st1 {v23.h}[5], [x10], x8 + st1 {v23.h}[6], [x15], x8 + cmp x6, 
#8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + st1 {v25.h}[5], [x10], x8 + st1 {v25.h}[6], [x15], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + st1 {v27.h}[5], [x10], x8 + st1 {v27.h}[6], [x15], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + st1 {v29.h}[5], [x10], x8 + st1 {v29.h}[6], [x15], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + st1 {v31.h}[5], [x10], x8 + st1 {v31.h}[6], [x15], x8 + add x11, x11, #30 + b WriteEnd + Write16: + add x2, x2, #32 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.8h}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.8h}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.8h}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.8h}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.8h}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.8h}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.8h}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.8h}, [x19], x8 + add x11, x11, #32 + b WriteEnd + + WriteEnd: + subs x13, x13, #16 // col - 16 + ble LoopColEnd + cmp x6, #1 + ble LoopCol1 + cmp x6, #2 + ble LoopCol2 + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol12 + +LoopColEnd: + add x0, x0, x17 + mov x21, #2 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S new file mode 100644 index 0000000000000000000000000000000000000000..228e73ec79506d65c82dc44707408857e779db86 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S @@ -0,0 +1,960 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
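This new file (MatmulBaseFp16Neon.S) adds an fp16 GEMM kernel tiled over 16-, 8- and 4-row blocks. As a hedged reference for the contract it appears to implement, a scalar C sketch follows; the function name is hypothetical, the unpacked row-major layouts are a simplification (the assembly expects tile-packed lhs/rhs), and stride is counted in elements, matching the "add x8, x8, x8" byte-stride conversion in the prologue.

#include <stddef.h>

/* Hedged scalar reference, not part of this patch. __fp16 arithmetic promotes
 * through float on AArch64 (requires an AArch64 toolchain), so rounding can
 * differ slightly from the kernel's native fp16 fmla accumulation; this sketch
 * only pins down the contract. */
static void MatmulBaseFp16Ref(const __fp16 *a, const __fp16 *b, __fp16 *c,
                              const __fp16 *bias, int act_type, int depth,
                              int row, int col, size_t stride) {
  for (int r = 0; r < row; ++r) {
    for (int j = 0; j < col; ++j) {
      float acc = (bias != NULL) ? (float)bias[j] : 0.0f;  /* kernel preloads bias into v16..v31 */
      for (int d = 0; d < depth; ++d) {
        acc += (float)a[r * depth + d] * (float)b[d * col + j];
      }
      if (act_type == 1 || act_type == 3) acc = acc < 0.0f ? 0.0f : acc;  /* Relu */
      if (act_type == 3) acc = acc > 6.0f ? 6.0f : acc;                   /* Relu6 upper clamp */
      c[r * stride + j] = (__fp16)acc;  /* stride in elements; the asm doubles it to bytes */
    }
  }
}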
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulBaseFp16Neon + sub sp, sp, #160 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] // writeMode + add x8, x8, x8 // stride * sizeof(float16_t) + + add x16, x7, x7 // col * sizeof(float16_t) + add x17, x5, x5 // depth * sizeof(float16_t) + mov x11, x2 + dup v12.8h, wzr + movi v13.8h, #0x46, lsl #8 +LoopRowStart: + cmp x6, #16 + bge LoopRow16 + cmp x6, #8 + bge LoopRow8 + b LoopRow4 + +LoopRow16: + mov x15, #16 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol16: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + mov v24.16b, v16.16b + mov v25.16b, v16.16b + mov v26.16b, v16.16b + mov v27.16b, v16.16b + mov v28.16b, v16.16b + mov v29.16b, v16.16b + mov v30.16b, v16.16b + mov v31.16b, v16.16b + + cmp x19, #4 + blt LoopDepth16One + + LoopDepth16: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h,
v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + + subs x19, x19, #4 + beq Activation16 + cmp x19, #4 + bge LoopDepth16 + + LoopDepth16One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + fmla v24.8h, v2.8h, v1.h[0] + fmla v25.8h, v2.8h, v1.h[1] + fmla v26.8h, v2.8h, v1.h[2] + fmla v27.8h, v2.8h, v1.h[3] + fmla v28.8h, v2.8h, v1.h[4] + fmla v29.8h, v2.8h, v1.h[5] + fmla v30.8h, v2.8h, v1.h[6] + fmla v31.8h, v2.8h, v1.h[7] + subs x19, x19, #1 + bgt LoopDepth16One + + Activation16: + cmp x4, #3 + beq Relu616 + cmp x4, #1 + beq Relu16 + b Write16 + Relu616: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + fmin v20.8h, v20.8h, v13.8h + fmin v21.8h, v21.8h, v13.8h + fmin v22.8h, v22.8h, v13.8h + fmin v23.8h, v23.8h, v13.8h + fmin v24.8h, v24.8h, v13.8h + fmin v25.8h, v25.8h, v13.8h + fmin v26.8h, v26.8h, v13.8h + fmin v27.8h, v27.8h, v13.8h + fmin v28.8h, v28.8h, v13.8h + fmin v29.8h, v29.8h, v13.8h + fmin v30.8h, v30.8h, v13.8h + fmin v31.8h, v31.8h, v13.8h + Relu16: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + fmax v20.8h, v20.8h, v12.8h + fmax v21.8h, v21.8h, v12.8h + fmax v22.8h, v22.8h, v12.8h + fmax v23.8h, v23.8h, v12.8h + fmax v24.8h, v24.8h, v12.8h + fmax v25.8h, v25.8h, v12.8h + fmax v26.8h, v26.8h, v12.8h + fmax v27.8h, v27.8h, v12.8h + fmax v28.8h, v28.8h, v12.8h + fmax v29.8h, v29.8h, v12.8h + fmax v30.8h, v30.8h, v12.8h + fmax v31.8h, v31.8h, v12.8h + Write16: + cmp x13, #8 + bge Write16x8 + b Write + Write16x8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x11], x8 + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x11], x8 + st1 {v24.8h}, [x11], x8 + st1 {v25.8h}, [x11], x8 + st1 {v26.8h}, [x11], x8 + st1 {v27.8h}, [x11], x8 + st1 {v28.8h}, [x11], x8 + st1 {v29.8h}, [x11], x8 + st1 {v30.8h}, [x11], x8 + st1 {v31.8h}, [x11], x8 + b WriteEnd + +LoopRow8: + mov x15, #8 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + + cmp x19, #4 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + fmla v20.8h, v9.8h, v1.h[4] + fmla v21.8h, v9.8h, v1.h[5] + fmla v22.8h, v9.8h, v1.h[6] + fmla v23.8h, v9.8h, v1.h[7] + fmla v16.8h, v10.8h, v2.h[0] + fmla v17.8h, v10.8h, v2.h[1] + fmla v18.8h, v10.8h, v2.h[2] + fmla v19.8h, v10.8h, v2.h[3] + fmla v20.8h, 
v10.8h, v2.h[4] + fmla v21.8h, v10.8h, v2.h[5] + fmla v22.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v2.h[7] + fmla v16.8h, v11.8h, v3.h[0] + fmla v17.8h, v11.8h, v3.h[1] + fmla v18.8h, v11.8h, v3.h[2] + fmla v19.8h, v11.8h, v3.h[3] + fmla v20.8h, v11.8h, v3.h[4] + fmla v21.8h, v11.8h, v3.h[5] + fmla v22.8h, v11.8h, v3.h[6] + fmla v23.8h, v11.8h, v3.h[7] + subs x19, x19, #4 + beq Activation8 + cmp x19, #4 + bge LoopDepth8 + LoopDepth8One: + ld1 {v0.8h}, [x10], #16 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + subs x19, x19, #1 + bgt LoopDepth8One + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write8_Row + Relu68: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + fmin v20.8h, v20.8h, v13.8h + fmin v21.8h, v21.8h, v13.8h + fmin v22.8h, v22.8h, v13.8h + fmin v23.8h, v23.8h, v13.8h + Relu8: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + fmax v20.8h, v20.8h, v12.8h + fmax v21.8h, v21.8h, v12.8h + fmax v22.8h, v22.8h, v12.8h + fmax v23.8h, v23.8h, v12.8h + Write8_Row: + cmp x13, #8 // row + bge Write8x8 + b Write + Write8x8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x11], x8 + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x11], x8 + b WriteEnd + +LoopRow4: + mov x15, #4 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + cmp x19, #4 + blt LoopDepth4One + LoopDepth4: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v16.8h, v9.8h, v0.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v18.8h, v9.8h, v0.h[6] + fmla v19.8h, v9.8h, v0.h[7] + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + fmla v16.8h, v11.8h, v1.h[4] + fmla v17.8h, v11.8h, v1.h[5] + fmla v18.8h, v11.8h, v1.h[6] + fmla v19.8h, v11.8h, v1.h[7] + subs x19, x19, #4 + beq Activation4 + cmp x19, #4 + bge LoopDepth4 + LoopDepth4One: + ld1 {v0.4h}, [x10], #8 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + subs x19, x19, #1 + bgt LoopDepth4One + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write4_Row + Relu64: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + Relu4: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + Write4_Row: + cmp x6, #4 + bge Write4x8 + b Write + Write4x8: + cmp x13, #8 + blt Write + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + b WriteEnd + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, 
#4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #2 + st1 {v16.h}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.h}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.h}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.h}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.h}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.h}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.h}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.h}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.h}[0], [x11], x8 + b WriteEnd + Write2: + add x2, x2, #4 + st1 {v16.s}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + b WriteEnd + Write3: + add x2, x2, #6 + add x19, x11, #4 + st1 {v16.s}[0], [x11], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + st1 {v17.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + st1 {v19.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + st1 {v21.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + st1 {v23.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + st1 {v25.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + st1 {v27.h}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + st1 {v29.h}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + st1 {v30.h}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + st1 {v31.h}[2], [x19] + b WriteEnd + Write4: + add x2, x2, #8 + st1 {v16.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + cmp x6, #4 + beq 
WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.h}[4], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.h}[4], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.h}[4], [x19] + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x10, x11, #12 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 
{v17.s}[2], [x19], x8 + st1 {v17.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + st1 {v19.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + st1 {v21.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + st1 {v23.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + st1 {v25.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + st1 {v27.h}[6], [x10], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + st1 {v29.h}[6], [x10], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + st1 {v30.h}[6], [x10], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + st1 {v31.h}[6], [x10] + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.8h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.8h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.8h}, [x11], x8 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #16 + bge LoopCol16 + cmp x6, #8 + bge LoopCol8 + b LoopCol4 + +LoopColEnd: + sub x2, x2, x16 // dst - col * 2 + mul x21, x8, x15 // row_block * col * 2 + add x2, x2, x21 + subs x6, x6, x15 + mul x15, x15, x17 + add x0, x0, x15 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..0f01e6f1f73f902f618219bd1756b3dcaed358d4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S @@ -0,0 +1,892 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
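This kernel (MatmulFp16Neon64) and the others in this patch share one inner-loop idiom: a vector of 8 rhs columns is scaled by successive lhs scalars with the indexed form of fmla, so each accumulator register holds one output row of 8 columns. A minimal intrinsics sketch of one such step, assuming an ARMv8.2-A toolchain built with +fp16 (the helper name is illustrative):

#include <arm_neon.h>

/* One depth step of an 8-row x 8-col fp16 tile: acc[r] += b * a[r].
 * vfmaq_laneq_f16 compiles to the indexed "fmla vd.8h, vn.8h, vm.h[lane]"
 * instruction used throughout these kernels. */
static inline void Fp16TileStep(float16x8_t acc[8], float16x8_t b, float16x8_t a) {
  acc[0] = vfmaq_laneq_f16(acc[0], b, a, 0);
  acc[1] = vfmaq_laneq_f16(acc[1], b, a, 1);
  acc[2] = vfmaq_laneq_f16(acc[2], b, a, 2);
  acc[3] = vfmaq_laneq_f16(acc[3], b, a, 3);
  acc[4] = vfmaq_laneq_f16(acc[4], b, a, 4);
  acc[5] = vfmaq_laneq_f16(acc[5], b, a, 5);
  acc[6] = vfmaq_laneq_f16(acc[6], b, a, 6);
  acc[7] = vfmaq_laneq_f16(acc[7], b, a, 7);
}

The half-precision by-element fmla can only index an element register in v0-v15, which is why the lhs data stays in v0-v7 and the rhs vectors occupy v8-v15; with v16-v31 all serving as accumulators in the 16-row paths, the callee-saved v8-v15 (and x19/x20) must be spilled in the prologue and restored with "ld1 ... [sp], #64" before ret, as each of these functions does.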
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, int stride, bool write_nhwc) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: row +// w7: col +// w17: stride +// w13: writeC8 + +asm_function MatmulFp16Neon64 + sub sp, sp, #144 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + + mov w18, #16 // sizeof(float16) * 8 + mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float16) * 8 * depth + mov x11, x3 // bias flag + mov x19, #2 + ldr x17, [sp, #144] + mul x17, x17, x19 + +L1: + mov w10, w6 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x19, x2 // reload dst ptr + +L2: + mov x16, x1 // reload rhs ptr + mov w13, w5 // reload depth + mov x14, x3 // reload bias ptr + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp w13, #8 + blt CommLoopMul + +OptLoopMul8: + ld1 {v0.8h, v1.8h}, [x12], #32 + ld1 {v8.8h, v9.8h}, [x16], #32 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + ld1 {v2.8h, v3.8h}, [x12], #32 + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v10.8h, v11.8h}, [x16], #32 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x16], #64 + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + ld1 {v4.8h, v5.8h}, [x12], #32 + fmla v16.8h, v11.8h, v6.h[0] + fmla 
v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + ld1 {v6.8h, v7.8h}, [x12], #32 + fmla v16.8h, v12.8h, v0.h[0] + fmla v17.8h, v12.8h, v0.h[1] + fmla v18.8h, v12.8h, v0.h[2] + fmla v19.8h, v12.8h, v0.h[3] + fmla v20.8h, v12.8h, v0.h[4] + fmla v21.8h, v12.8h, v0.h[5] + fmla v22.8h, v12.8h, v0.h[6] + fmla v23.8h, v12.8h, v0.h[7] + fmla v24.8h, v12.8h, v1.h[0] + fmla v25.8h, v12.8h, v1.h[1] + fmla v26.8h, v12.8h, v1.h[2] + fmla v27.8h, v12.8h, v1.h[3] + fmla v28.8h, v12.8h, v1.h[4] + fmla v29.8h, v12.8h, v1.h[5] + fmla v30.8h, v12.8h, v1.h[6] + fmla v31.8h, v12.8h, v1.h[7] + fmla v16.8h, v13.8h, v2.h[0] + fmla v17.8h, v13.8h, v2.h[1] + fmla v18.8h, v13.8h, v2.h[2] + fmla v19.8h, v13.8h, v2.h[3] + fmla v20.8h, v13.8h, v2.h[4] + fmla v21.8h, v13.8h, v2.h[5] + fmla v22.8h, v13.8h, v2.h[6] + fmla v23.8h, v13.8h, v2.h[7] + fmla v24.8h, v13.8h, v3.h[0] + fmla v25.8h, v13.8h, v3.h[1] + fmla v26.8h, v13.8h, v3.h[2] + fmla v27.8h, v13.8h, v3.h[3] + fmla v28.8h, v13.8h, v3.h[4] + fmla v29.8h, v13.8h, v3.h[5] + fmla v30.8h, v13.8h, v3.h[6] + fmla v31.8h, v13.8h, v3.h[7] + fmla v16.8h, v14.8h, v4.h[0] + fmla v17.8h, v14.8h, v4.h[1] + fmla v18.8h, v14.8h, v4.h[2] + fmla v19.8h, v14.8h, v4.h[3] + fmla v20.8h, v14.8h, v4.h[4] + fmla v21.8h, v14.8h, v4.h[5] + fmla v22.8h, v14.8h, v4.h[6] + fmla v23.8h, v14.8h, v4.h[7] + fmla v24.8h, v14.8h, v5.h[0] + fmla v25.8h, v14.8h, v5.h[1] + fmla v26.8h, v14.8h, v5.h[2] + fmla v27.8h, v14.8h, v5.h[3] + fmla v28.8h, v14.8h, v5.h[4] + fmla v29.8h, v14.8h, v5.h[5] + fmla v30.8h, v14.8h, v5.h[6] + fmla v31.8h, v14.8h, v5.h[7] + fmla v16.8h, v15.8h, v6.h[0] + fmla v17.8h, v15.8h, v6.h[1] + fmla v18.8h, v15.8h, v6.h[2] + fmla v19.8h, v15.8h, v6.h[3] + fmla v20.8h, v15.8h, v6.h[4] + fmla v21.8h, v15.8h, v6.h[5] + fmla v22.8h, v15.8h, v6.h[6] + fmla v23.8h, v15.8h, v6.h[7] + fmla v24.8h, v15.8h, v7.h[0] + fmla v25.8h, v15.8h, v7.h[1] + fmla v26.8h, v15.8h, v7.h[2] + fmla v27.8h, v15.8h, v7.h[3] + fmla v28.8h, v15.8h, v7.h[4] + fmla v29.8h, v15.8h, v7.h[5] + fmla v30.8h, v15.8h, v7.h[6] + fmla v31.8h, v15.8h, v7.h[7] + + sub w13, w13, #8 + cmp w13, #0 + ble Bias + cmp w13, #8 + bge OptLoopMul8 + +CommLoopMul: + ld1 {v0.8h, v1.8h}, [x12], #32 + ld1 {v8.8h}, [x16], #16 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + + subs w13, w13, #1 + bgt CommLoopMul + +Bias: + cbz x11, Activation + ld1 {v0.8h}, [x14], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v0.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v0.8h + 
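// (illustrative note) x11 was set to the bias pointer at entry ("mov x11, x3 // bias flag"),
// so the "cbz x11, Activation" above doubles as a null-bias check: when bias is NULL the
// whole fadd block is skipped. MatmulBaseFp16Neon takes the other common approach, preloading
// the bias row into the accumulators (ld1 {v16.8h}, [x12] followed by mov v17.16b, v16.16b, ...)
// so no post-add is needed; in C terms, acc = vld1q_f16(bias) up front versus
// acc = vaddq_f16(acc, vld1q_f16(bias)) after the depth loop.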
fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v0.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v0.8h + +Activation: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + movi v15.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v15.8h + fmin v17.8h, v17.8h, v15.8h + fmin v18.8h, v18.8h, v15.8h + fmin v19.8h, v19.8h, v15.8h + fmin v20.8h, v20.8h, v15.8h + fmin v21.8h, v21.8h, v15.8h + fmin v22.8h, v22.8h, v15.8h + fmin v23.8h, v23.8h, v15.8h + fmin v24.8h, v24.8h, v15.8h + fmin v25.8h, v25.8h, v15.8h + fmin v26.8h, v26.8h, v15.8h + fmin v27.8h, v27.8h, v15.8h + fmin v28.8h, v28.8h, v15.8h + fmin v29.8h, v29.8h, v15.8h + fmin v30.8h, v30.8h, v15.8h + fmin v31.8h, v31.8h, v15.8h + +Relu: + dup v14.4s, wzr + fmax v16.8h, v16.8h, v14.8h + fmax v17.8h, v17.8h, v14.8h + fmax v18.8h, v18.8h, v14.8h + fmax v19.8h, v19.8h, v14.8h + fmax v20.8h, v20.8h, v14.8h + fmax v21.8h, v21.8h, v14.8h + fmax v22.8h, v22.8h, v14.8h + fmax v23.8h, v23.8h, v14.8h + fmax v24.8h, v24.8h, v14.8h + fmax v25.8h, v25.8h, v14.8h + fmax v26.8h, v26.8h, v14.8h + fmax v27.8h, v27.8h, v14.8h + fmax v28.8h, v28.8h, v14.8h + fmax v29.8h, v29.8h, v14.8h + fmax v30.8h, v30.8h, v14.8h + fmax v31.8h, v31.8h, v14.8h + +Write: + ldrb w13, [sp, #152] + cbz w13, WriteC8 + cmp w7, #1 + beq Write1 + cmp w7, #2 + beq Write2 + cmp w7, #3 + beq Write3 + cmp w7, #4 + beq Write4 + cmp w7, #5 + beq Write5 + cmp w7, #6 + beq Write6 + cmp w7, #7 + beq Write7 + b Write8 + +Write1: + st1 {v16.h}[0], [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + b WriteEnd +Write2: + add x13, x19, #2 + st1 {v16.h}[0], [x19], x17 + st1 {v16.h}[1], [x13], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + st1 {v17.h}[1], [x13], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + st1 {v18.h}[1], [x13], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + st1 {v19.h}[1], [x13], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + st1 {v20.h}[1], [x13], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + st1 {v21.h}[1], [x13], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + st1 {v22.h}[1], [x13], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + st1 {v23.h}[1], [x13], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + st1 {v24.h}[1], [x13], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + st1 {v25.h}[1], [x13], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + st1 {v26.h}[1], [x13], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + st1 {v27.h}[1], [x13], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + st1 {v28.h}[1], [x13], x17 + cmp w10, #13 + beq WriteEnd + 
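// (illustrative note) The Relu6 bound above is built with "movi v15.8h, #0x46, lsl #8",
// which writes 0x4600 into every half-word; in IEEE binary16 that is exactly 6.0
// (sign 0, exponent 10001b -> 2^(17-15) = 4, mantissa 1.1b = 1.5, and 1.5 * 4 = 6.0).
// A quick C check (hypothetical, assumes __fp16 is IEEE binary16 on AArch64, with
// <string.h> and <assert.h>):
//   unsigned short bits = 0x4600; __fp16 f; memcpy(&f, &bits, 2); assert((float)f == 6.0f);
// Relu6 then falls through into the Relu fmax against zero, giving the full clamp to [0, 6].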
st1 {v29.h}[0], [x19], x17 + st1 {v29.h}[1], [x13], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + st1 {v30.h}[1], [x13], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + st1 {v31.h}[1], [x13], x17 + b WriteEnd +Write3: + add x13, x19, #2 + add x14, x19, #4 + st1 {v16.h}[0], [x19], x17 + st1 {v16.h}[1], [x13], x17 + st1 {v16.h}[2], [x14], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + st1 {v17.h}[1], [x13], x17 + st1 {v17.h}[2], [x14], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + st1 {v18.h}[1], [x13], x17 + st1 {v18.h}[2], [x14], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + st1 {v19.h}[1], [x13], x17 + st1 {v19.h}[2], [x14], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + st1 {v20.h}[1], [x13], x17 + st1 {v20.h}[2], [x14], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + st1 {v21.h}[1], [x13], x17 + st1 {v21.h}[2], [x14], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + st1 {v22.h}[1], [x13], x17 + st1 {v22.h}[2], [x14], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + st1 {v23.h}[1], [x13], x17 + st1 {v23.h}[2], [x14], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + st1 {v24.h}[1], [x13], x17 + st1 {v24.h}[2], [x14], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + st1 {v25.h}[1], [x13], x17 + st1 {v25.h}[2], [x14], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + st1 {v26.h}[1], [x13], x17 + st1 {v26.h}[2], [x14], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + st1 {v27.h}[1], [x13], x17 + st1 {v27.h}[2], [x14], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + st1 {v28.h}[1], [x13], x17 + st1 {v28.h}[2], [x14], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + st1 {v29.h}[1], [x13], x17 + st1 {v29.h}[2], [x14], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + st1 {v30.h}[1], [x13], x17 + st1 {v30.h}[2], [x14], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + st1 {v31.h}[1], [x13], x17 + st1 {v31.h}[2], [x14], x17 + b WriteEnd +Write4: + st1 {v16.4h}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + b WriteEnd +Write5: + add x13, x19, #8 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + cmp w10, 
#6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + b WriteEnd +Write6: + add x13, x19, #8 + add x14, x19, #10 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + st1 {v16.h}[5], [x14], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + st1 {v17.h}[5], [x14], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + st1 {v18.h}[5], [x14], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + st1 {v19.h}[5], [x14], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + st1 {v20.h}[5], [x14], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + st1 {v21.h}[5], [x14], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + st1 {v22.h}[5], [x14], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + st1 {v23.h}[5], [x14], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + st1 {v24.h}[5], [x14], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + st1 {v25.h}[5], [x14], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + st1 {v26.h}[5], [x14], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + st1 {v27.h}[5], [x14], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + st1 {v28.h}[5], [x14], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + st1 {v29.h}[5], [x14], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + st1 {v30.h}[5], [x14], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + st1 {v31.h}[5], [x14], x17 + b WriteEnd +Write7: + add x13, x19, #8 + add x14, x19, #10 + add x16, x19, #12 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + st1 {v16.h}[5], [x14], x17 + st1 {v16.h}[6], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + st1 {v17.h}[5], [x14], x17 + st1 {v17.h}[6], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + st1 {v18.h}[5], [x14], x17 + st1 {v18.h}[6], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + st1 {v19.h}[5], [x14], x17 + st1 {v19.h}[6], [x16], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + st1 {v20.h}[5], [x14], x17 + st1 {v20.h}[6], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], 
x17 + st1 {v21.h}[5], [x14], x17 + st1 {v21.h}[6], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + st1 {v22.h}[5], [x14], x17 + st1 {v22.h}[6], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + st1 {v23.h}[5], [x14], x17 + st1 {v23.h}[6], [x16], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + st1 {v24.h}[5], [x14], x17 + st1 {v24.h}[6], [x16], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + st1 {v25.h}[5], [x14], x17 + st1 {v25.h}[6], [x16], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + st1 {v26.h}[5], [x14], x17 + st1 {v26.h}[6], [x16], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + st1 {v27.h}[5], [x14], x17 + st1 {v27.h}[6], [x16], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + st1 {v28.h}[5], [x14], x17 + st1 {v28.h}[6], [x16], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + st1 {v29.h}[5], [x14], x17 + st1 {v29.h}[6], [x16], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + st1 {v30.h}[5], [x14], x17 + st1 {v30.h}[6], [x16], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + st1 {v31.h}[5], [x14], x17 + st1 {v31.h}[6], [x16], x17 + b WriteEnd +WriteC8: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64 + b WriteEnd +Write8: + st1 {v16.8h}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.8h}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.8h}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.8h}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.8h}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.8h}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.8h}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.8h}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.8h}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.8h}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.8h}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.8h}, [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.8h}, [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.8h}, [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.8h}, [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.8h}, [x19], x17 + +WriteEnd: + subs w10, w10, #16 // lhs row - 8 + bgt L2 + +End2: + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + add x3, x3, #16 // bias ptr + stride + ldrb w13, [sp, #152] + cbz w13, NoDstStep + add x2, x2, #16 // dst ptr + stride +NoDstStep: + bgt L1 + +End1: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..c55e83a5b1ae6f9aa5121d0ce229c656c7479580 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S @@ -0,0 +1,1185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFp16Neon64Opt + sub sp, sp, #96 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + + ldr x8, [sp, #96] + ldr x9, [sp, #104] + + mov x21, #32 // sizeof(float16_t) * 16 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float16_t) * 16 * depth + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #16 + mul x16, x6, x21 // row * 8 * sizeof(float16_t) +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #2 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float16_t) + mov x21, #16 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float16_t) +NoWinoSteps: + mov x21, #2 + mul x8, x8, x21 + +LoopRowStart: + cmp x6, #1 + ble LoopRow + cmp x6, #2 + ble LoopRow2 + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow16: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol16: + cbz x9, NoReloadDst16 + mov x11, x2 + NoReloadDst16: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x19, #4 + blt LoopDepth16One + + LoopDepth16: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla 
v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + + subs x19, x19, #4 + beq Bias16 + cmp x19, #4 + bge LoopDepth16 + + LoopDepth16One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + fmla v24.8h, v2.8h, v1.h[0] + fmla v25.8h, v2.8h, v1.h[1] + fmla v26.8h, v2.8h, v1.h[2] + fmla v27.8h, v2.8h, v1.h[3] + fmla v28.8h, v2.8h, v1.h[4] + fmla v29.8h, v2.8h, v1.h[5] + fmla v30.8h, v2.8h, v1.h[6] + fmla v31.8h, v2.8h, v1.h[7] + + subs x19, x19, #1 + bgt LoopDepth16One + + Bias16: + cbz x3, Activation16 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v0.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v0.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v0.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v0.8h + + Activation16: + cmp x4, #3 + beq Relu616 + cmp x4, #1 + beq Relu16 + b Write + + Relu616: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + fmin v24.8h, v24.8h, v2.8h + fmin v25.8h, v25.8h, v2.8h + fmin v26.8h, v26.8h, v2.8h + fmin v27.8h, v27.8h, v2.8h + fmin v28.8h, v28.8h, v2.8h + fmin v29.8h, v29.8h, v2.8h + fmin v30.8h, v30.8h, v2.8h + fmin v31.8h, v31.8h, v2.8h + + Relu16: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + fmax v24.8h, v24.8h, v2.8h + fmax v25.8h, v25.8h, v2.8h + fmax v26.8h, v26.8h, v2.8h + fmax v27.8h, v27.8h, v2.8h + fmax v28.8h, v28.8h, v2.8h + fmax v29.8h, v29.8h, v2.8h + fmax v30.8h, v30.8h, v2.8h + fmax v31.8h, v31.8h, v2.8h + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + mov x11, x2 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + cmp x19, #4 + blt LoopDepth8One + + LoopDepth8: + ld1 
{v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + + subs x19, x19, #4 + beq Bias8 + cmp x19, #4 + bge LoopDepth8 + + LoopDepth8One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + + subs x19, x19, #1 + bgt LoopDepth8One + + Bias8: + cbz x3, Activation8 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + + Relu8: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + mov x11, x2 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + + cmp x19, #4 + blt LoopDepth4One + + LoopDepth4: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + + subs x19, x19, #4 + beq Bias4 + cmp x19, #4 + bge LoopDepth4 + + LoopDepth4One: + ld1 {v0.8h, v1.8h}, [x10], 
#32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + + subs x19, x19, #1 + bgt LoopDepth4One + + Bias4: + cbz x3, Activation4 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + + Relu4: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + b Write + +LoopRow2: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol2: + cbz x9, NoReloadDst2 + mov x11, x2 + NoReloadDst2: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + + cmp x19, #4 + blt LoopDepth2One + + LoopDepth2: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + + subs x19, x19, #4 + beq Bias2 + cmp x19, #4 + bge LoopDepth2 + + LoopDepth2One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + + subs x19, x19, #1 + bgt LoopDepth2One + + Bias2: + cbz x3, Activation2 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + + Activation2: + cmp x4, #3 + beq Relu62 + cmp x4, #1 + beq Relu2 + b Write + + Relu62: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + + Relu2: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + b Write + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + mov x11, x2 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + + cmp x19, #4 + blt LoopDepthOne + + LoopDepth: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v16.8h, v10.8h, v4.h[0] + fmla v16.8h, v11.8h, v6.h[0] + + subs x19, x19, #4 + beq Bias + cmp x19, #4 + bge LoopDepth + + LoopDepthOne: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + + subs x19, x19, #1 + bgt LoopDepthOne + + Bias: + cbz x3, Activation + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + + Relu: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + + Write: + cmp x9, #2 + beq WriteWino + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #2 + str h16, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str h17, [x11] + cmp x6, #2 + beq 
WriteEnd + add x11, x11, x8 + str h18, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str h19, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str h20, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str h21, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str h22, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str h23, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str h24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str h25, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str h26, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str h27, [x11] + cmp x6, #12 + beq WriteEnd + add x11, x11, x8 + str h28, [x11] + cmp x6, #13 + beq WriteEnd + add x11, x11, x8 + str h29, [x11] + cmp x6, #14 + beq WriteEnd + add x11, x11, x8 + str h30, [x11] + cmp x6, #15 + beq WriteEnd + add x11, x11, x8 + str h31, [x11] + add x11, x11, x8 + add x11, x11, #2 + b WriteEnd + Write2: + add x2, x2, #4 + st1 {v16.s}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + add x11, x11, #4 + b WriteEnd + Write3: + add x2, x2, #6 + add x19, x11, #4 + st1 {v16.s}[0], [x11], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + st1 {v17.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + st1 {v19.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + st1 {v21.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + st1 {v23.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + st1 {v25.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + st1 {v27.h}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + st1 {v29.h}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + st1 {v30.h}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + st1 {v31.h}[2], [x19] + add x11, x11, #6 + b WriteEnd + Write4: + add x2, x2, #8 + st1 {v16.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #5 + beq 
WriteEnd + st1 {v21.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.h}[4], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.h}[4], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.h}[4], [x19] + add x11, x11, #10 + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x10, x11, #12 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 
{v17.s}[2], [x19], x8 + st1 {v17.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + st1 {v19.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + st1 {v21.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + st1 {v23.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + st1 {v25.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + st1 {v27.h}[6], [x10], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + st1 {v29.h}[6], [x10], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + st1 {v30.h}[6], [x10], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + st1 {v31.h}[6], [x10] + add x11, x11, #14 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x19], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x19], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x19], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v16.8h}, [x11], x15 + st1 {v17.8h}, [x11], x15 + st1 {v18.8h}, [x11], x15 + st1 {v19.8h}, [x11], x15 + st1 {v20.8h}, [x11], x15 + st1 {v21.8h}, [x11], x15 + st1 {v22.8h}, [x11], x15 + st1 {v23.8h}, [x11], x15 + st1 {v24.8h}, [x11], x15 + st1 {v25.8h}, [x11], x15 + st1 {v26.8h}, [x11], x15 + st1 {v27.8h}, [x11], x15 + st1 {v28.8h}, [x11], x15 + st1 {v29.8h}, [x11], x15 + st1 {v30.8h}, [x11], x15 + st1 {v31.8h}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.8h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.8h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.8h}, [x11], x8 + add x11, x11, #16 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #1 + ble LoopCol + cmp x6, #2 + ble LoopCol2 + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol16 + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + mov x21, #2 + mul x21, x21, x7 + sub x11, x11, x21 + mov 
x2, x11 + b NoDstStep + C8DstStep: + add x2, x2, #256 + mov x11, x2 + NoDstStep: + subs x6, x6, #16 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S new file mode 100644 index 0000000000000000000000000000000000000000..545e075509b2c465bbd82467d822ae5272d114b0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S @@ -0,0 +1,2966 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// size_t depth, size_t row, size_t col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFp16OptV2 + sub sp, sp, #192 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x29, x30, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] // writeMode + lsl x8, x8, #1 // stride * sizeof(float16_t) + + lsl x15, x7, #1 // col * sizeof(float16_t) + lsl x16, x5, #1 // depth * sizeof(float16_t) + mov x11, x2 + movi v7.8h, #0x46, lsl #8 + subs x6, x6, #12 + blt LoopRow8 +LoopRow12: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol12x8 + LoopCol12x16: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias12x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b Compute12x16Enter + InitFromBias12x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + ld1 {v16.8h, v17.8h}, [x12] + ld1 {v18.8h, v19.8h}, [x12] + ld1 {v20.8h, v21.8h}, [x12] + ld1 {v22.8h, v23.8h}, [x12] + ld1 {v24.8h, v25.8h}, [x12] + ld1 {v26.8h, v27.8h}, [x12] + ld1 {v28.8h, v29.8h}, [x12] + ld1 {v30.8h, v31.8h}, [x12] + add x12, x12, #32 + Compute12x16Enter: + bl Compute12x16Unit + Activation12x16: + cmp x4, #3 + beq Relu612x16 + cmp x4, #1 + beq Relu12x16 + b Write12x16 + + 
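// Relu6 clamps against v7.8h, preloaded at function entry with fp16 6.0 (movi v7.8h, #0x46, lsl #8 -> 0x4600 per lane); the Relu label below then applies the zero floor with fmax + 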
Relu612x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + Relu12x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + Write12x16: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + st1 {v16.8h, v17.8h}, [x23], x8 + st1 {v18.8h, v19.8h}, [x23], x8 + st1 {v20.8h, v21.8h}, [x23], x8 + st1 {v22.8h, v23.8h}, [x23] + st1 {v24.8h, v25.8h}, [x24], x8 + st1 {v26.8h, v27.8h}, [x24], x8 + st1 {v28.8h, v29.8h}, [x24], x8 + st1 {v30.8h, v31.8h}, [x24] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol12x16 + + LoopCol12x8: + adds x13, x13, #16 + cbz x13, LoopRow12End + subs x13, x13, #8 + blt LoopCol12x4 + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias12x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b Compute12x8Enter + InitFromBias12x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + ld1 {v16.8h}, [x12] + ld1 {v18.8h}, [x12] + ld1 {v20.8h}, [x12] + ld1 {v22.8h}, [x12] + ld1 {v24.8h}, [x12] + ld1 {v26.8h}, [x12] + ld1 {v28.8h}, [x12] + ld1 {v30.8h}, [x12] + add x12, x12, #16 + Compute12x8Enter: + bl Compute12x8Unit + Activation12x8: + cmp x4, #3 + beq Relu612x8 + cmp x4, #1 + beq Relu12x8 + b Write12x8 + + Relu612x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + + Relu12x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v24.8h, v24.8h, 
v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + Write12x8: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + st1 {v16.8h}, [x23], x8 + st1 {v18.8h}, [x23], x8 + st1 {v20.8h}, [x23], x8 + st1 {v22.8h}, [x23] + st1 {v24.8h}, [x24], x8 + st1 {v26.8h}, [x24], x8 + st1 {v28.8h}, [x24], x8 + st1 {v30.8h}, [x24] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol12x4: + adds x13, x13, #8 + cbz x13, LoopRow12End + LoopCol12x4Core: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias12x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + dup v16.2s, wzr + dup v18.2s, wzr + dup v20.2s, wzr + dup v22.2s, wzr + dup v24.2s, wzr + dup v26.2s, wzr + dup v28.2s, wzr + dup v30.2s, wzr + b Compute12x4Enter + InitFromBias12x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + ld1 {v16.4h}, [x12] + ld1 {v18.4h}, [x12] + ld1 {v20.4h}, [x12] + ld1 {v22.4h}, [x12] + ld1 {v24.4h}, [x12] + ld1 {v26.4h}, [x12] + ld1 {v28.4h}, [x12] + ld1 {v30.4h}, [x12] + add x12, x12, #8 + Compute12x4Enter: + bl Compute12x4Unit + Activation12x4: + cmp x4, #3 + beq Relu612x4 + cmp x4, #1 + beq Relu12x4 + b Write12x4 + + Relu612x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + fmin v16.4h, v16.4h, v7.4h + fmin v18.4h, v18.4h, v7.4h + fmin v20.4h, v20.4h, v7.4h + fmin v22.4h, v22.4h, v7.4h + fmin v24.4h, v24.4h, v7.4h + fmin v26.4h, v26.4h, v7.4h + fmin v28.4h, v28.4h, v7.4h + fmin v30.4h, v30.4h, v7.4h + + Relu12x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + fmax v16.4h, v16.4h, v6.4h + fmax v18.4h, v18.4h, v6.4h + fmax v20.4h, v20.4h, v6.4h + fmax v22.4h, v22.4h, v6.4h + fmax v24.4h, v24.4h, v6.4h + fmax v26.4h, v26.4h, v6.4h + fmax v28.4h, v28.4h, v6.4h + fmax v30.4h, v30.4h, v6.4h + Write12x4: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + cmp x13, #1 + beq Write12x1 + cmp x13, #2 + beq Write12x2 + cmp x13, #3 + beq Write12x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + st1 {v16.4h}, [x23], x8 + st1 {v18.4h}, [x23], x8 + st1 {v20.4h}, [x23], x8 + st1 {v22.4h}, [x23] + st1 {v24.4h}, [x24], x8 + st1 {v26.4h}, [x24], x8 + st1 {v28.4h}, [x24], x8 + st1 {v30.4h}, [x24] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol12x4Core + b LoopRow12End + Write12x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + st1 {v16.h}[0], [x23], x8 + st1 {v18.h}[0], [x23], x8 + st1 {v20.h}[0], [x23], x8 + st1 {v22.h}[0], [x23] + st1 {v24.h}[0], [x24], x8 + st1 {v26.h}[0], [x24], x8 + st1 {v28.h}[0], [x24], x8 + st1 {v30.h}[0], [x24] + b LoopRow12End + Write12x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + st1 {v16.s}[0], [x23], x8 + st1 {v18.s}[0], [x23], x8 + st1 {v20.s}[0], [x23], x8 + st1 {v22.s}[0], [x23] + st1 {v24.s}[0], [x24], x8 + st1 {v26.s}[0], [x24], x8 + st1 {v28.s}[0], [x24], x8 + st1 {v30.s}[0], [x24] + b LoopRow12End + Write12x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 
{v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + st1 {v16.s}[0], [x22], x8 + st1 {v16.h}[2], [x23], x8 + st1 {v18.s}[0], [x22], x8 + st1 {v18.h}[2], [x23], x8 + st1 {v20.s}[0], [x22], x8 + st1 {v20.h}[2], [x23], x8 + st1 {v22.s}[0], [x22], x8 + st1 {v22.h}[2], [x23], x8 + st1 {v24.s}[0], [x22], x8 + st1 {v24.h}[2], [x23], x8 + st1 {v26.s}[0], [x22], x8 + st1 {v26.h}[2], [x23], x8 + st1 {v28.s}[0], [x22], x8 + st1 {v28.h}[2], [x23], x8 + st1 {v30.s}[0], [x22] + st1 {v30.h}[2], [x23] + LoopRow12End: + add x0, x0, x16, lsl #3 + add x0, x0, x16, lsl #2 + add x2, x2, x8, lsl #3 + add x2, x2, x8, lsl #2 + subs x6, x6, #12 + bge LoopRow12 + +LoopRow8: + adds x6, x6, #12 + cbz x6, End + subs x6, x6, #8 + blt LoopRow4 + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol8x8 + LoopCol8x16: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias8x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b Compute8x16Enter + InitFromBias8x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + ld1 {v16.8h, v17.8h}, [x12] + ld1 {v18.8h, v19.8h}, [x12] + ld1 {v20.8h, v21.8h}, [x12] + ld1 {v22.8h, v23.8h}, [x12] + add x12, x12, #32 + Compute8x16Enter: + bl Compute8x16Unit + Activation8x16: + cmp x4, #3 + beq Relu68x16 + cmp x4, #1 + beq Relu8x16 + b Write8x16 + + Relu68x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + Relu8x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + Write8x16: + mov x22, x21 + add x23, x21, x8, lsl #2 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + st1 {v16.8h, v17.8h}, [x23], x8 + st1 {v18.8h, v19.8h}, [x23], x8 + st1 {v20.8h, v21.8h}, [x23], x8 + st1 {v22.8h, v23.8h}, [x23] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol8x16 + + LoopCol8x8: + adds x13, x13, #16 + cbz x13, LoopRow8End + subs x13, x13, #8 + blt LoopCol8x4 + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias8x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + 
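// the 8-wide column tail accumulates into the even-numbered registers v8..v22 only; the odd registers are unused in this path + 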
dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b Compute8x8Enter + InitFromBias8x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + ld1 {v16.8h}, [x12] + ld1 {v18.8h}, [x12] + ld1 {v20.8h}, [x12] + ld1 {v22.8h}, [x12] + add x12, x12, #16 + Compute8x8Enter: + bl Compute8x8Unit + Activation8x8: + cmp x4, #3 + beq Relu68x8 + cmp x4, #1 + beq Relu8x8 + b Write8x8 + + Relu68x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + + Relu8x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + Write8x8: + mov x22, x21 + add x23, x21, x8, lsl #2 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + st1 {v16.8h}, [x23], x8 + st1 {v18.8h}, [x23], x8 + st1 {v20.8h}, [x23], x8 + st1 {v22.8h}, [x23] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol8x4: + adds x13, x13, #8 + cbz x13, LoopRow8End + LoopCol8x4Core: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias8x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + dup v16.2s, wzr + dup v18.2s, wzr + dup v20.2s, wzr + dup v22.2s, wzr + b Compute8x4Enter + InitFromBias8x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + ld1 {v16.4h}, [x12] + ld1 {v18.4h}, [x12] + ld1 {v20.4h}, [x12] + ld1 {v22.4h}, [x12] + add x12, x12, #8 + Compute8x4Enter: + bl Compute8x4Unit + Activation8x4: + cmp x4, #3 + beq Relu68x4 + cmp x4, #1 + beq Relu8x4 + b Write8x4 + + Relu68x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + fmin v16.4h, v16.4h, v7.4h + fmin v18.4h, v18.4h, v7.4h + fmin v20.4h, v20.4h, v7.4h + fmin v22.4h, v22.4h, v7.4h + + Relu8x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + fmax v16.4h, v16.4h, v6.4h + fmax v18.4h, v18.4h, v6.4h + fmax v20.4h, v20.4h, v6.4h + fmax v22.4h, v22.4h, v6.4h + Write8x4: + mov x22, x21 + add x23, x21, x8, lsl #2 + cmp x13, #1 + beq Write8x1 + cmp x13, #2 + beq Write8x2 + cmp x13, #3 + beq Write8x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + st1 {v16.4h}, [x23], x8 + st1 {v18.4h}, [x23], x8 + st1 {v20.4h}, [x23], x8 + st1 {v22.4h}, [x23] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol8x4Core + b LoopRow8End + Write8x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + st1 {v16.h}[0], [x23], x8 + st1 {v18.h}[0], [x23], x8 + st1 {v20.h}[0], [x23], x8 + st1 {v22.h}[0], [x23] + b LoopRow8End + Write8x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + st1 {v16.s}[0], [x23], x8 + st1 {v18.s}[0], [x23], x8 + st1 {v20.s}[0], [x23], x8 + st1 {v22.s}[0], [x23] + b LoopRow8End + Write8x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 
{v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + st1 {v16.s}[0], [x22], x8 + st1 {v16.h}[2], [x23], x8 + st1 {v18.s}[0], [x22], x8 + st1 {v18.h}[2], [x23], x8 + st1 {v20.s}[0], [x22], x8 + st1 {v20.h}[2], [x23], x8 + st1 {v22.s}[0], [x22], x8 + st1 {v22.h}[2], [x23], x8 + LoopRow8End: + add x0, x0, x16, lsl #3 + add x2, x2, x8, lsl #3 + subs x6, x6, #8 + +LoopRow4: + adds x6, x6, #8 + cbz x6, End + subs x6, x6, #4 + blt LoopRowTail + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol4x8 + LoopCol4x16: + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias4x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b Compute4x16Enter + InitFromBias4x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + add x12, x12, #32 + Compute4x16Enter: + bl Compute4x16Unit + Activation4x16: + cmp x4, #3 + beq Relu64x16 + cmp x4, #1 + beq Relu4x16 + b Write4x16 + + Relu64x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + + Relu4x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + Write4x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol4x16 + + LoopCol4x8: + adds x13, x13, #16 + cbz x13, LoopRow4End + subs x13, x13, #8 + blt LoopCol4x4 + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias4x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b Compute4x8Enter + InitFromBias4x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + add x12, x12, #16 + Compute4x8Enter: + bl Compute4x8Unit + Activation4x8: + cmp x4, #3 + beq Relu64x8 + cmp x4, #1 + beq Relu4x8 + b Write4x8 + + Relu64x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + + Relu4x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + Write4x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol4x4: + adds x13, x13, #8 + cbz x13, LoopRow4End + LoopCol4x4Core: + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias4x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + b Compute4x4Enter + InitFromBias4x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + add x12, x12, #8 + Compute4x4Enter: + bl Compute4x4Unit + Activation4x4: + 
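// act_type dispatch: 3 selects Relu6, 1 selects Relu, any other value stores the raw accumulators + 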
cmp x4, #3 + beq Relu64x4 + cmp x4, #1 + beq Relu4x4 + b Write4x4 + + Relu64x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + + Relu4x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + Write4x4: + mov x22, x21 + cmp x13, #1 + beq Write4x1 + cmp x13, #2 + beq Write4x2 + cmp x13, #3 + beq Write4x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol4x4Core + b LoopRow4End + Write4x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + b LoopRow4End + Write4x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + b LoopRow4End + Write4x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + LoopRow4End: + add x0, x0, x16, lsl #2 + add x2, x2, x8, lsl #2 + subs x6, x6, #4 + +LoopRowTail: + adds x6, x6, #4 + cbz x6, End + cmp x6, #1 + beq LoopRow1 + cmp x6, #2 + beq LoopRow2 + // LoopRow3 + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol3x8 + LoopCol3x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + b Compute3x16Enter + InitFromBias3x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + add x12, x12, #32 + Compute3x16Enter: + bl Compute3x16Unit + Activation3x16: + cmp x4, #3 + beq Relu63x16 + cmp x4, #1 + beq Relu3x16 + b Write3x16 + + Relu63x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + + Relu3x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + Write3x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol3x16 + + LoopCol3x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol3x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + b Compute3x8Enter + InitFromBias3x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + add x12, x12, #16 + Compute3x8Enter: + bl Compute3x8Unit + Activation3x8: + cmp x4, #3 + beq Relu63x8 + cmp x4, #1 + beq Relu3x8 + b Write3x8 + + Relu63x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + + Relu3x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + Write3x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol3x4: + adds x13, x13, #8 + cbz x13, End + LoopCol3x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, 
InitFromBias3x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + b Compute3x4Enter + InitFromBias3x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + add x12, x12, #8 + Compute3x4Enter: + bl Compute3x4Unit + Activation3x4: + cmp x4, #3 + beq Relu63x4 + cmp x4, #1 + beq Relu3x4 + b Write3x4 + + Relu63x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + + Relu3x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + Write3x4: + mov x22, x21 + cmp x13, #1 + beq Write3x1 + cmp x13, #2 + beq Write3x2 + cmp x13, #3 + beq Write3x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol3x4Core + b End + Write3x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22] + b End + Write3x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22] + b End + Write3x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + b End + +LoopRow2: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol2x8 + LoopCol2x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + b Compute2x16Enter + InitFromBias2x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + add x12, x12, #32 + Compute2x16Enter: + bl Compute2x16Unit + Activation2x16: + cmp x4, #3 + beq Relu62x16 + cmp x4, #1 + beq Relu2x16 + b Write2x16 + + Relu62x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + + Relu2x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + Write2x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol2x16 + + LoopCol2x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol2x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x8 + dup v8.2d, xzr + dup v10.2d, xzr + b Compute2x8Enter + InitFromBias2x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + add x12, x12, #16 + Compute2x8Enter: + bl Compute2x8Unit + Activation2x8: + cmp x4, #3 + beq Relu62x8 + cmp x4, #1 + beq Relu2x8 + b Write2x8 + + Relu62x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + + Relu2x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + Write2x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol2x4: + adds x13, x13, #8 + cbz x13, End + LoopCol2x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x4 + dup v8.2s, wzr + dup v10.2s, wzr + b Compute2x4Enter + InitFromBias2x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + add x12, x12, #8 + Compute2x4Enter: + bl Compute2x4Unit + Activation2x4: + cmp x4, #3 + beq Relu62x4 + cmp x4, #1 + beq Relu2x4 + b Write2x4 + + Relu62x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + Relu2x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + Write2x4: + mov x22, x21 + cmp x13, #1 + beq Write2x1 + cmp x13, 
#2 + beq Write2x2 + cmp x13, #3 + beq Write2x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol2x4Core + b End + Write2x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22] + b End + Write2x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22] + b End + Write2x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + b End + +LoopRow1: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol1x8 + LoopCol1x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x16 + dup v8.2d, xzr + dup v9.2d, xzr + b Compute1x16Enter + InitFromBias1x16: + ld1 {v8.8h, v9.8h}, [x12], #32 + Compute1x16Enter: + bl Compute1x16Unit + Activation1x16: + cmp x4, #3 + beq Relu61x16 + cmp x4, #1 + beq Relu1x16 + b Write1x16 + + Relu61x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + + Relu1x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + Write1x16: + st1 {v8.8h, v9.8h}, [x21], #32 + subs x13, x13, #16 + bge LoopCol1x16 + + LoopCol1x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol1x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x8 + dup v8.2d, xzr + b Compute1x8Enter + InitFromBias1x8: + ld1 {v8.8h}, [x12], #16 + Compute1x8Enter: + bl Compute1x8Unit + Activation1x8: + cmp x4, #3 + beq Relu61x8 + cmp x4, #1 + beq Relu1x8 + b Write1x8 + + Relu61x8: + fmin v8.8h, v8.8h, v7.8h + + Relu1x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + Write1x8: + st1 {v8.8h}, [x21], #16 + subs x13, x13, #8 + + LoopCol1x4: + adds x13, x13, #8 + cbz x13, End + LoopCol1x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x4 + dup v8.2s, wzr + b Compute1x4Enter + InitFromBias1x4: + ld1 {v8.4h}, [x12], #8 + Compute1x4Enter: + bl Compute1x4Unit + Activation1x4: + cmp x4, #3 + beq Relu61x4 + cmp x4, #1 + beq Relu1x4 + b Write1x4 + + Relu61x4: + fmin v8.4h, v8.4h, v7.4h + Relu1x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + Write1x4: + cmp x13, #1 + beq Write1x1 + cmp x13, #2 + beq Write1x2 + cmp x13, #3 + beq Write1x3 + st1 {v8.4h}, [x21], #8 + subs x13, x13, #4 + bgt LoopCol1x4Core + b End + Write1x1: + st1 {v8.h}[0], [x21] + b End + Write1x2: + st1 {v8.s}[0], [x21] + b End + Write1x3: + add x22, x21, #4 + st1 {v8.s}[0], [x21] + st1 {v8.h}[2], [x22] + b End + +Compute12x16Unit: + subs x14, x14, #2 + ble Compute12x16End + Compute12x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + + fmla v8.8h, v5.8h, v1.h[4] + fmla v10.8h, 
v5.8h, v1.h[5] + fmla v12.8h, v5.8h, v1.h[6] + fmla v14.8h, v5.8h, v1.h[7] + fmla v16.8h, v5.8h, v2.h[0] + fmla v18.8h, v5.8h, v2.h[1] + fmla v20.8h, v5.8h, v2.h[2] + fmla v22.8h, v5.8h, v2.h[3] + fmla v24.8h, v5.8h, v2.h[4] + fmla v26.8h, v5.8h, v2.h[5] + fmla v28.8h, v5.8h, v2.h[6] + fmla v30.8h, v5.8h, v2.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[4] + fmla v11.8h, v6.8h, v1.h[5] + fmla v13.8h, v6.8h, v1.h[6] + fmla v15.8h, v6.8h, v1.h[7] + prfm pldl1keep, [x10, #632] + ld1 {v0.8h}, [x10], #16 + fmla v17.8h, v6.8h, v2.h[0] + fmla v19.8h, v6.8h, v2.h[1] + fmla v21.8h, v6.8h, v2.h[2] + fmla v23.8h, v6.8h, v2.h[3] + fmla v25.8h, v6.8h, v2.h[4] + fmla v27.8h, v6.8h, v2.h[5] + fmla v29.8h, v6.8h, v2.h[6] + fmla v31.8h, v6.8h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x16 + Compute12x16End: + cbnz x14, Compute12x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + ld1 {v2.8h}, [x10], #16 + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + mov v0.16b, v2.16b + Compute12x16End1: + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + ret + +Compute12x8Unit: + subs x14, x14, #2 + ble Compute12x8End + Compute12x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[4] + fmla v10.8h, v4.8h, v1.h[5] + fmla v12.8h, v4.8h, v1.h[6] + fmla v14.8h, v4.8h, v1.h[7] + ld1 {v0.8h}, [x10], #16 + fmla v16.8h, v4.8h, v2.h[0] + fmla v18.8h, v4.8h, v2.h[1] + fmla v20.8h, v4.8h, v2.h[2] + fmla v22.8h, v4.8h, v2.h[3] + fmla v24.8h, v4.8h, v2.h[4] + fmla v26.8h, v4.8h, v2.h[5] + fmla v28.8h, v4.8h, v2.h[6] + fmla v30.8h, v4.8h, v2.h[7] + + subs x14, x14, #2 
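+ // the unrolled loop consumes two depth steps per pass; Compute12x8End below finishes the final one or two steps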
+ bgt Compute12x8 + Compute12x8End: + cbnz x14, Compute12x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + mov v3.16b, v4.16b + Compute12x8End1: + ld1 {v1.4h}, [x10] + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + ret + +Compute12x4Unit: + subs x14, x14, #2 + ble Compute12x4End + Compute12x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[4] + fmla v10.4h, v4.4h, v1.h[5] + fmla v12.4h, v4.4h, v1.h[6] + fmla v14.4h, v4.4h, v1.h[7] + ld1 {v0.8h}, [x10], #16 + fmla v16.4h, v4.4h, v2.h[0] + fmla v18.4h, v4.4h, v2.h[1] + fmla v20.4h, v4.4h, v2.h[2] + fmla v22.4h, v4.4h, v2.h[3] + fmla v24.4h, v4.4h, v2.h[4] + fmla v26.4h, v4.4h, v2.h[5] + fmla v28.4h, v4.4h, v2.h[6] + fmla v30.4h, v4.4h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x4 + Compute12x4End: + cbnz x14, Compute12x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + mov v3.8b, v4.8b + Compute12x4End1: + ld1 {v1.4h}, [x10] + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + ret + +Compute8x16Unit: + subs x14, x14, #2 + ble Compute8x16End + Compute8x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, 
v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + + fmla v8.8h, v5.8h, v1.h[0] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v1.h[2] + fmla v14.8h, v5.8h, v1.h[3] + fmla v16.8h, v5.8h, v1.h[4] + fmla v18.8h, v5.8h, v1.h[5] + fmla v20.8h, v5.8h, v1.h[6] + fmla v22.8h, v5.8h, v1.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[0] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v1.h[2] + fmla v15.8h, v6.8h, v1.h[3] + prfm pldl1keep, [x10, #632] + ld1 {v0.8h}, [x10], #16 + fmla v17.8h, v6.8h, v1.h[4] + fmla v19.8h, v6.8h, v1.h[5] + fmla v21.8h, v6.8h, v1.h[6] + fmla v23.8h, v6.8h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x16 + Compute8x16End: + cbnz x14, Compute8x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + mov v0.16b, v1.16b + Compute8x16End1: + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + ret + +Compute8x8Unit: + subs x14, x14, #2 + ble Compute8x8End + Compute8x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v1.h[2] + fmla v14.8h, v4.8h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + fmla v16.8h, v4.8h, v1.h[4] + fmla v18.8h, v4.8h, v1.h[5] + fmla v20.8h, v4.8h, v1.h[6] + fmla v22.8h, v4.8h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x8 + Compute8x8End: + cbnz x14, Compute8x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + mov v0.16b, v1.16b + mov v3.16b, v4.16b + Compute8x8End1: + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + ret + +Compute8x4Unit: + subs x14, x14, #2 + ble Compute8x4End + Compute8x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, 
v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v1.h[2] + fmla v14.4h, v4.4h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + fmla v16.4h, v4.4h, v1.h[4] + fmla v18.4h, v4.4h, v1.h[5] + fmla v20.4h, v4.4h, v1.h[6] + fmla v22.4h, v4.4h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x4 + Compute8x4End: + cbnz x14, Compute8x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + mov v0.16b, v1.16b + mov v3.8b, v4.8b + Compute8x4End1: + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + ret + +Compute4x16Unit: + subs x14, x14, #2 + ble Compute4x16End + Compute4x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + + fmla v8.8h, v5.8h, v1.h[0] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v1.h[2] + fmla v14.8h, v5.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[0] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v1.h[2] + fmla v15.8h, v6.8h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x16 + Compute4x16End: + cbnz x14, Compute4x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + mov v0.8b, v1.8b + Compute4x16End1: + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + ret + +Compute4x8Unit: + subs x14, x14, #2 + ble Compute4x8End + Compute4x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v1.h[2] + fmla v14.8h, v4.8h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x8 + Compute4x8End: + cbnz x14, Compute4x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + mov v0.8b, v1.8b + mov v3.16b, v4.16b + Compute4x8End1: + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, 
v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + ret + +Compute4x4Unit: + subs x14, x14, #2 + ble Compute4x4End + Compute4x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v1.h[2] + fmla v14.4h, v4.4h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x4 + Compute4x4End: + cbnz x14, Compute4x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + mov v0.8b, v1.8b + mov v3.8b, v4.8b + Compute4x4End1: + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + ret + +Compute3x16Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x16End4 + Compute3x16: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v12.8h, v5.8h, v2.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + fmla v13.8h, v6.8h, v2.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + fmla v12.8h, v3.8h, v2.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v11.8h, v4.8h, v1.h[4] + fmla v13.8h, v4.8h, v2.h[4] + fmla v8.8h, v5.8h, v0.h[5] + fmla v10.8h, v5.8h, v1.h[5] + fmla v12.8h, v5.8h, v2.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v11.8h, v6.8h, v1.h[5] + fmla v13.8h, v6.8h, v2.h[5] + fmla v8.8h, v3.8h, v0.h[6] + fmla v10.8h, v3.8h, v1.h[6] + fmla v12.8h, v3.8h, v2.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v11.8h, v4.8h, v1.h[6] + fmla v13.8h, v4.8h, v2.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v10.8h, v5.8h, v1.h[7] + fmla v12.8h, v5.8h, v2.h[7] + fmla v9.8h, v6.8h, v0.h[7] + fmla v11.8h, v6.8h, v1.h[7] + fmla v13.8h, v6.8h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x16 + Compute3x16End4: + adds x14, x14, #8 + cbz x14, Compute3x16Return + subs x14, x14, #4 + blt Compute3x16EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] 
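+    // finish k-step 1 on the odd 8-lane accumulators (v9/v11/v13), then apply k-steps 2 and 3 of this 4-step depth tail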
+ ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v12.8h, v5.8h, v2.h[3] + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + fmla v13.8h, v6.8h, v2.h[3] + subs x14, x14, #4 + Compute3x16EndTail: + adds x14, x14, #4 + cbz x14, Compute3x16Return + cmp x14, #1 + beq Compute3x16EndTail1 + cmp x14, #2 + beq Compute3x16EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + b Compute3x16Return + Compute3x16EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + b Compute3x16Return + Compute3x16EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + Compute3x16Return: + ret + +Compute3x8Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x8End4 + Compute3x8: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v12.8h, v6.8h, v2.h[3] + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + fmla v12.8h, v3.8h, v2.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v10.8h, v4.8h, v1.h[5] + fmla v12.8h, v4.8h, v2.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v10.8h, v5.8h, v1.h[6] + fmla v12.8h, v5.8h, v2.h[6] + fmla v8.8h, v6.8h, v0.h[7] + fmla v10.8h, v6.8h, v1.h[7] + fmla v12.8h, v6.8h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x8 + Compute3x8End4: + adds x14, 
x14, #8 + cbz x14, Compute3x8Return + subs x14, x14, #4 + blt Compute3x8EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v12.8h, v6.8h, v2.h[3] + subs x14, x14, #4 + Compute3x8EndTail: + adds x14, x14, #4 + cbz x14, Compute3x8Return + cmp x14, #1 + beq Compute3x8EndTail1 + cmp x14, #2 + beq Compute3x8EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h}, [x11], #16 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + b Compute3x8Return + Compute3x8EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld2 {v2.h, v3.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v5.8h, v0.h[0] + fmla v10.8h, v5.8h, v1.h[0] + fmla v12.8h, v5.8h, v2.h[0] + fmla v8.8h, v6.8h, v0.h[1] + fmla v10.8h, v6.8h, v1.h[1] + fmla v12.8h, v6.8h, v3.h[0] + b Compute3x8Return + Compute3x8EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + Compute3x8Return: + ret + +Compute3x4Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x4End4 + Compute3x4: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v12.4h, v6.4h, v2.h[3] + fmla v8.4h, v3.4h, v0.h[4] + fmla v10.4h, v3.4h, v1.h[4] + fmla v12.4h, v3.4h, v2.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v10.4h, v4.4h, v1.h[5] + fmla v12.4h, v4.4h, v2.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v10.4h, v5.4h, v1.h[6] + fmla v12.4h, v5.4h, v2.h[6] + fmla v8.4h, v6.4h, v0.h[7] + fmla v10.4h, v6.4h, v1.h[7] + fmla v12.4h, v6.4h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x4 + Compute3x4End4: + adds x14, x14, #8 + cbz x14, Compute3x4Return + subs x14, x14, #4 + blt Compute3x4EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + 
fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v12.4h, v6.4h, v2.h[3] + subs x14, x14, #4 + Compute3x4EndTail: + adds x14, x14, #4 + cbz x14, Compute3x4Return + cmp x14, #1 + beq Compute3x4EndTail1 + cmp x14, #2 + beq Compute3x4EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h}, [x11], #8 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + b Compute3x4Return + Compute3x4EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld2 {v2.h, v3.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v5.4h, v0.h[0] + fmla v10.4h, v5.4h, v1.h[0] + fmla v12.4h, v5.4h, v2.h[0] + fmla v8.4h, v6.4h, v0.h[1] + fmla v10.4h, v6.4h, v1.h[1] + fmla v12.4h, v6.4h, v3.h[0] + b Compute3x4Return + Compute3x4EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + Compute3x4Return: + ret + +Compute2x16Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x16End4 + Compute2x16: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v11.8h, v4.8h, v1.h[4] + fmla v8.8h, v5.8h, v0.h[5] + fmla v10.8h, v5.8h, v1.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v11.8h, v6.8h, v1.h[5] + fmla v8.8h, v3.8h, v0.h[6] + fmla v10.8h, v3.8h, v1.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v11.8h, v4.8h, v1.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v10.8h, v5.8h, v1.h[7] + fmla v9.8h, v6.8h, v0.h[7] + fmla v11.8h, v6.8h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x16 + Compute2x16End4: + adds x14, x14, #8 + cbz x14, Compute2x16Return + subs x14, x14, #4 + blt Compute2x16EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v9.8h, 
v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + subs x14, x14, #4 + Compute2x16EndTail: + adds x14, x14, #4 + cbz x14, Compute2x16Return + cmp x14, #1 + beq Compute2x16EndTail1 + cmp x14, #2 + beq Compute2x16EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.s}[0], [x19], #4 + ld1 {v1.h}[2], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + b Compute2x16Return + Compute2x16EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v2.h[0] + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v2.h[0] + b Compute2x16Return + Compute2x16EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + Compute2x16Return: + ret + +Compute2x8Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x8End4 + Compute2x8: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v10.8h, v4.8h, v1.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v10.8h, v5.8h, v1.h[6] + fmla v8.8h, v6.8h, v0.h[7] + fmla v10.8h, v6.8h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x8 + Compute2x8End4: + adds x14, x14, #8 + cbz x14, Compute2x8Return + subs x14, x14, #4 + blt Compute2x8EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + subs x14, x14, #4 + Compute2x8EndTail: + adds x14, x14, #4 + cbz x14, Compute2x8Return + cmp x14, #1 + beq Compute2x8EndTail1 + cmp x14, #2 + beq Compute2x8EndTail2 + ld1 {v0.4h}, [x10] + ld3 {v1.h, v2.h, v3.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[0] + fmla v10.8h, v4.8h, v1.h[0] + ld1 {v6.8h}, [x11], #16 + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v2.h[0] + fmla v8.8h, v6.8h, v0.h[2] + fmla v10.8h, v6.8h, v3.h[0] + b Compute2x8Return + Compute2x8EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v8.8h, 
v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v2.h[0] + b Compute2x8Return + Compute2x8EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + Compute2x8Return: + ret + +Compute2x4Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x4End4 + Compute2x4: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v8.4h, v3.4h, v0.h[4] + fmla v10.4h, v3.4h, v1.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v10.4h, v4.4h, v1.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v10.4h, v5.4h, v1.h[6] + fmla v8.4h, v6.4h, v0.h[7] + fmla v10.4h, v6.4h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x4 + Compute2x4End4: + adds x14, x14, #8 + cbz x14, Compute2x4Return + subs x14, x14, #4 + blt Compute2x4EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + subs x14, x14, #4 + Compute2x4EndTail: + adds x14, x14, #4 + cbz x14, Compute2x4Return + cmp x14, #1 + beq Compute2x4EndTail1 + cmp x14, #2 + beq Compute2x4EndTail2 + ld1 {v0.4h}, [x10] + ld3 {v1.h, v2.h, v3.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v4.4h, v5.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[0] + fmla v10.4h, v4.4h, v1.h[0] + ld1 {v6.4h}, [x11], #8 + fmla v8.4h, v5.4h, v0.h[1] + fmla v10.4h, v5.4h, v2.h[0] + fmla v8.4h, v6.4h, v0.h[2] + fmla v10.4h, v6.4h, v3.h[0] + b Compute2x4Return + Compute2x4EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v2.h[0] + b Compute2x4Return + Compute2x4EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + Compute2x4Return: + ret + +Compute1x16Unit: + subs x14, x14, #8 + blt Compute1x16End4 + Compute1x16: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v0.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v8.8h, v3.8h, v0.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v8.8h, v5.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v8.8h, v5.8h, v0.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v8.8h, v3.8h, v0.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v9.8h, v6.8h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x16 + 
Compute1x16End4: + adds x14, x14, #8 + cbz x14, Compute1x16Return + subs x14, x14, #4 + blt Compute1x16EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v0.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v8.8h, v3.8h, v0.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v9.8h, v6.8h, v0.h[3] + subs x14, x14, #4 + Compute1x16EndTail: + adds x14, x14, #4 + cbz x14, Compute1x16Return + cmp x14, #1 + beq Compute1x16EndTail1 + cmp x14, #2 + beq Compute1x16EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v1.h[0] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v1.h[0] + fmla v8.8h, v3.8h, v2.h[0] + fmla v9.8h, v4.8h, v2.h[0] + b Compute1x16Return + Compute1x16EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v1.h[0] + fmla v9.8h, v6.8h, v1.h[0] + b Compute1x16Return + Compute1x16EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + Compute1x16Return: + ret + +Compute1x8Unit: + subs x14, x14, #8 + blt Compute1x8End4 + Compute1x8: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v8.8h, v5.8h, v0.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v8.8h, v3.8h, v0.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v8.8h, v6.8h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x8 + Compute1x8End4: + adds x14, x14, #8 + cbz x14, Compute1x8Return + subs x14, x14, #4 + blt Compute1x8EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v8.8h, v6.8h, v0.h[3] + subs x14, x14, #4 + Compute1x8EndTail: + adds x14, x14, #4 + cbz x14, Compute1x8Return + cmp x14, #1 + beq Compute1x8EndTail1 + cmp x14, #2 + beq Compute1x8EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v2.h[0] + b Compute1x8Return + Compute1x8EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v8.8h, v4.8h, v1.h[0] + b Compute1x8Return + Compute1x8EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + Compute1x8Return: + ret + +Compute1x4Unit: + subs x14, x14, #8 + blt Compute1x4End4 + Compute1x4: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v8.4h, v5.4h, v0.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla 
v8.4h, v6.4h, v0.h[3] + fmla v8.4h, v3.4h, v0.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v8.4h, v6.4h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x4 + Compute1x4End4: + adds x14, x14, #8 + cbz x14, Compute1x4Return + subs x14, x14, #4 + blt Compute1x4EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v8.4h, v6.4h, v0.h[3] + subs x14, x14, #4 + Compute1x4EndTail: + adds x14, x14, #4 + cbz x14, Compute1x4Return + cmp x14, #1 + beq Compute1x4EndTail1 + cmp x14, #2 + beq Compute1x4EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v8.4h, v5.4h, v2.h[0] + b Compute1x4Return + Compute1x4EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v8.4h, v4.4h, v1.h[0] + b Compute1x4Return + Compute1x4EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + Compute1x4Return: + ret + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..ac1351705ab7e32da4400ccb26507529e2eaea7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S @@ -0,0 +1,217 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinogradFp16(float16_t *matrix_a, float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, int in_channel) + // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel +asm_function MatrixMultiplyWinogradFp16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such an amount of parameters + sub sp, sp, #48 + st1 {v8.8h}, [sp] + stp x19, x20, [sp, #16] + stp x21, x22, [sp, #32] + + mov x8, #2 + mul x10, x5, x8 // n * 2 + mov x17, x3 // m + mul x13, x6, x8 // in_channel * 2 + mul x21, x13, x4 // in_channel * k * 2 + + LoopM: + mov x15, x5 // n + mov x14, x1 // mat_b + LoopN: + mov x16, x0 // mat_a_m + sub x22, x5, x15 // ni + sub x19, x17, x3 // mi + mul x22, x22, x17 // ni * m + mov x11, x6 // in_channel + add x22, x22, x19 // (ni * m) + mi + mul x22, x22, x13 // x22 * channel_in * 2 + add x20, x2, x22 // dst + offset + cmp x11, #32 + bge LoopC32 + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + b EndLoopC + LoopC32: + mov x12, x14 + mov x9, x4 // new_k + dup v5.8h, wzr + dup v6.8h, wzr + dup v7.8h, wzr + dup v8.8h, wzr + LoopK32: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + fmla v6.8h, v1.8h, v4.h[0] + fmla v7.8h, v2.8h, v4.h[0] + fmla v8.8h, v3.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK32 + Write32: + st1 {v5.8h}, [x20], #16 + st1 {v6.8h}, [x20], #16 + st1 {v7.8h}, [x20], #16 + st1 {v8.8h}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #64 // add 64B + subs x11, x11, #32 + beq EndLoopC + cmp x11, #32 + bge LoopC32 + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC16: + dup v5.8h, wzr + dup v6.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK16: + ld1 {v0.8h, v1.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + fmla v6.8h, v1.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK16 + Write16: + st1 {v5.8h}, [x20], #16 + st1 {v6.8h}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #32 // add 32B + subs x11, x11, #16 + beq EndLoopC + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC8: + dup v5.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK8: + ld1 {v0.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK8 + Write8: + st1 {v5.8h}, [x20], #16 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #16 // add 16B + subs x11, x11, #8 + beq EndLoopC + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC4: + dup v5.4h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK4: + ld1 {v0.4h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.4h, v0.4h, v4.h[0] + subs x9, x9, #1 + bne LoopK4 + Write4: + st1 {v5.4h}, [x20], #8 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #8 // add 8B + subs x11, x11, #4 + beq EndLoopC + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC: + dup v5.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK: + ldr h0, [x16] + add x16, x16, x13 + ldr h4, [x12] + add x12, x12, x10 + fmul 
h0, h0, h4 + fadd h5, h5, h0 + subs x9, x9, #1 + bne LoopK + Write: + str h5, [x20], #2 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #2 // ptr add 2B + subs x11, x11, #1 + beq EndLoopC + b LoopC + + EndLoopC: + add x14, x14, #2 + subs x15, x15, #1 + beq EndLoopN + b LoopN + EndLoopN: + subs x3, x3, #1 + beq EndLoopM + add x0, x0, x21 + b LoopM + + EndLoopM: + ld1 {v8.8h}, [sp], #16 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..82fff43078f2e94d707c29bf1e1b3e6ab4facb47 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S @@ -0,0 +1,293 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 src x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +asm_function PostFuncBiasReluC4Fp16 + + movi v26.4h, #6 + scvtf v26.4h, v26.4h + dup v27.4h, wzr + + mov x10, #2 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #2 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4h}, [x2], #8 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x1], #32 + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x1], #32 + + fadd v0.4h, v0.4h, v16.4h + fadd v1.4h, v1.4h, v16.4h + fadd v2.4h, v2.4h, v16.4h + fadd v3.4h, v3.4h, v16.4h + fadd v4.4h, v4.4h, v16.4h + fadd v5.4h, v5.4h, v16.4h + fadd v6.4h, v6.4h, v16.4h + fadd v7.4h, v7.4h, v16.4h + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4h, v0.4h, v26.4h + fmin v1.4h, v1.4h, v26.4h + fmin v2.4h, v2.4h, v26.4h + fmin v3.4h, v3.4h, v26.4h + fmin v4.4h, v4.4h, v26.4h + fmin v5.4h, v5.4h, v26.4h + fmin v6.4h, v6.4h, v26.4h + fmin v7.4h, v7.4h, v26.4h +Relu_8x4: + fmax v0.4h, v0.4h, v27.4h + fmax v1.4h, v1.4h, v27.4h + fmax v2.4h, v2.4h, v27.4h + fmax v3.4h, v3.4h, v27.4h + fmax v4.4h, v4.4h, v27.4h + fmax v5.4h, v5.4h, v27.4h + fmax v6.4h, v6.4h, v27.4h + fmax v7.4h, v7.4h, v27.4h +Write_8x4: + st1 {v0.4h}, [x15], x12 + st1 {v1.4h}, [x15], x12 + st1 {v2.4h}, [x15], x12 + st1 {v3.4h}, [x15], x12 + st1 {v4.4h}, [x15], x12 + st1 {v5.4h}, [x15], x12 + st1 {v6.4h}, [x15], x12 + st1 {v7.4h}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x1], #32 + fadd v0.4h, v0.4h, v16.4h + fadd v1.4h, v1.4h, v16.4h + fadd v2.4h, v2.4h, v16.4h + fadd v3.4h, v3.4h, 
v16.4h + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + fmin v0.4h, v0.4h, v26.4h + fmin v1.4h, v1.4h, v26.4h + fmin v2.4h, v2.4h, v26.4h + fmin v3.4h, v3.4h, v26.4h +Relu_4x4: + fmax v0.4h, v0.4h, v27.4h + fmax v1.4h, v1.4h, v27.4h + fmax v2.4h, v2.4h, v27.4h + fmax v3.4h, v3.4h, v27.4h +Write_4x4: + st1 {v0.4h}, [x15], x12 + st1 {v1.4h}, [x15], x12 + st1 {v2.4h}, [x15], x12 + st1 {v3.4h}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.4h}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.4h}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.4h}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp w4, #0 + beq End + mov w13, w5 + ld1 {v16.4h}, [x2], #8 + mov x15, #2 + mul x14, x10, x15 + add x0, x0, x14 + + cmp w4, #1 + beq Loop_C1_1 + cmp w4, #2 + beq Loop_C1_2 + cmp w4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #4 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..c339ac8c30afab20a4aca40d1b529a3f3243bfa7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S @@ -0,0 +1,469 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, +// size_t plane_size, size_t stride, int relu_type); +// x0 dst x1 src x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x22 x23 x24 x25 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc8 loop control +// w13 hw loop control + +asm_function PostFuncBiasReluC8Fp16 + movi v26.8h, #0x46, lsl #8 // 0x4600 = fp16 6.0 for relu6 + dup v27.8h, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x25, #2 + mul x24, x10, x25 + add x25, x0, x24 + add w10, w10, #8 + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + +Loop8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + fadd v4.8h, v4.8h, v16.8h + fadd v5.8h, v5.8h, v16.8h + fadd v6.8h, v6.8h, v16.8h + fadd v7.8h, v7.8h, v16.8h + + cmp w7, #2 + beq Relu6_8x8 + cmp w7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h +Relu_8x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h +Write_8x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + st1 {v4.8h}, [x25], x6 + st1 {v5.8h}, [x25], x6 + st1 {v6.8h}, [x25], x6 + st1 {v7.8h}, [x25], x6 + b Loop8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + + cmp w7, #2 + beq Relu6_4x8 + cmp w7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h +Relu_4x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h +Write_4x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + b Loop_4x8 + +Loop_1x8: + cmp w7, #2 + beq Relu6_1x8 + cmp w7, #1 + 
beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.8h}, [x25], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + mov x25, #2 + mul x24, x10, x25 + add x22, x0, x24 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp w7, #2 + beq Loop_C1_1_Relu6 + cmp w7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Write + +Loop_C1_2: + add x24, x0, #2 + cmp w7, #2 + beq Loop_C1_2_Relu6 + cmp w7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x24, x22, #2 + add x25, x22, #4 + cmp w7, #2 + beq Loop_C1_3_Relu6 + cmp w7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp w7, #2 + beq Loop_C1_4_Relu6 + cmp w7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + 
beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Relu +Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Write + +Loop_C1_5: + add x25, x22, #8 + cmp w7, #2 + beq Loop_C1_5_Relu6 + cmp w7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x23, x22, #8 + add x24, x22, #10 + cmp w7, #2 + beq Loop_C1_6_Relu6 + cmp w7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x23, x22, #8 + add x24, x22, #10 + add x25, x22, #12 + cmp w7, #2 + beq Loop_C1_7_Relu6 + cmp w7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Write + +End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..e0f2211ed41ed39be4609219acdcf10fe21ac90b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S @@ -0,0 +1,273 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp16 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] +add x9, sp, #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + +mov x7, #2 // sizeof(float16_t) +mul x3, x3, x7 +mov x7, #32 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + fmul v16.4h, v8.4h, v0.h[0] + fmul v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmul v18.4h, v8.4h, v2.h[0] + fmul v19.4h, v8.4h, v3.h[0] + ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32 + fmul v20.4h, v8.4h, v4.h[0] + fmul v21.4h, v8.4h, v5.h[0] + fmul v22.4h, v8.4h, v6.h[0] + fmul v23.4h, v8.4h, v7.h[0] + fmul v24.4h, v12.4h, v0.h[0] + fmul v25.4h, v12.4h, v1.h[0] + fmul v26.4h, v12.4h, v2.h[0] + fmul v27.4h, v12.4h, v3.h[0] + fmul v28.4h, v12.4h, v4.h[0] + fmul v29.4h, v12.4h, v5.h[0] + fmul v30.4h, v12.4h, v6.h[0] + fmul v31.4h, v12.4h, v7.h[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #64 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #64 + prfm pldl1keep, [x8, #64] + prfm pldl1keep, [x8, #96] + + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + fmla v24.4h, v13.4h, v0.h[1] + fmla v25.4h, v13.4h, v1.h[1] + fmla v26.4h, v13.4h, v2.h[1] + fmla v27.4h, v13.4h, v3.h[1] + fmla v28.4h, v13.4h, v4.h[1] + fmla v29.4h, v13.4h, v5.h[1] + fmla v30.4h, v13.4h, v6.h[1] + fmla v31.4h, v13.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + fmla v24.4h, v14.4h, v0.h[2] + fmla v25.4h, v14.4h, v1.h[2] + fmla v26.4h, v14.4h, v2.h[2] + fmla v27.4h, v14.4h, v3.h[2] + fmla v28.4h, v14.4h, v4.h[2] + fmla v29.4h, v14.4h, v5.h[2] + fmla v30.4h, v14.4h, v6.h[2] + fmla v31.4h, v14.4h, v7.h[2] + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + fmla v24.4h, v15.4h, v0.h[3] + fmla v25.4h, v15.4h, v1.h[3] + fmla v26.4h, v15.4h, v2.h[3] + fmla v27.4h, v15.4h, v3.h[3] + fmla v28.4h, v15.4h, v4.h[3] + fmla v29.4h, v15.4h, v5.h[3] + fmla v30.4h, v15.4h, v6.h[3] + fmla v31.4h, v15.4h, v7.h[3] + + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + fmla v16.4h, v8.4h, v0.h[0] + fmla v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmla v18.4h, v8.4h, v2.h[0] + fmla v19.4h, v8.4h, v3.h[0] + ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32 + fmla v20.4h, v8.4h, v4.h[0] + fmla v21.4h, v8.4h, v5.h[0] + fmla v22.4h, 
v8.4h, v6.h[0] + fmla v23.4h, v8.4h, v7.h[0] + fmla v24.4h, v12.4h, v0.h[0] + fmla v25.4h, v12.4h, v1.h[0] + fmla v26.4h, v12.4h, v2.h[0] + fmla v27.4h, v12.4h, v3.h[0] + fmla v28.4h, v12.4h, v4.h[0] + fmla v29.4h, v12.4h, v5.h[0] + fmla v30.4h, v12.4h, v6.h[0] + fmla v31.4h, v12.4h, v7.h[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + fmla v24.4h, v13.4h, v0.h[1] + fmla v25.4h, v13.4h, v1.h[1] + fmla v26.4h, v13.4h, v2.h[1] + fmla v27.4h, v13.4h, v3.h[1] + fmla v28.4h, v13.4h, v4.h[1] + fmla v29.4h, v13.4h, v5.h[1] + fmla v30.4h, v13.4h, v6.h[1] + fmla v31.4h, v13.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + fmla v24.4h, v14.4h, v0.h[2] + fmla v25.4h, v14.4h, v1.h[2] + fmla v26.4h, v14.4h, v2.h[2] + fmla v27.4h, v14.4h, v3.h[2] + fmla v28.4h, v14.4h, v4.h[2] + fmla v29.4h, v14.4h, v5.h[2] + fmla v30.4h, v14.4h, v6.h[2] + fmla v31.4h, v14.4h, v7.h[2] + + add x7, x0, #32 + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + fmla v24.4h, v15.4h, v0.h[3] + fmla v25.4h, v15.4h, v1.h[3] + fmla v26.4h, v15.4h, v2.h[3] + fmla v27.4h, v15.4h, v3.h[3] + fmla v28.4h, v15.4h, v4.h[3] + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x3 + fmla v29.4h, v15.4h, v5.h[3] + st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x7], x3 + fmla v30.4h, v15.4h, v6.h[3] + st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [x0], x3 + mov x2, x6 + fmla v31.4h, v15.4h, v7.h[3] + st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + fmla v16.4h, v8.4h, v0.h[0] + fmla v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmla v18.4h, v8.4h, v2.h[0] + fmla v19.4h, v8.4h, v3.h[0] + fmla v20.4h, v8.4h, v4.h[0] + fmla v21.4h, v8.4h, v5.h[0] + fmla v22.4h, v8.4h, v6.h[0] + fmla v23.4h, v8.4h, v7.h[0] + + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 + st1 {v20.4h, v21.4h, v22.4h, 
v23.4h}, [x0], #32 + +LoopOcEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..bf1803c2b03ab56a94a9d790ab573f02d2cf4515 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S @@ -0,0 +1,181 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_function VecMatmulFp16Neon64_2 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + +LoopCol: + mov x15, x0 // reload a ptr + ld1 {v0.8h}, [x3], #16 // acc0 + ld1 {v1.8h}, [x3], #16 // acc1 + mov w9, #0 // tmp depth + +Loop2x8Inner: + sub w18, w5, w9 + cmp w18, #8 + blt DepthRemain + + ld1 {v2.8h}, [x15], #16 // a + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x1], #64 + ld1 {v7.8h, v8.8h, v9.8h, v10.8h}, [x1], #64 + ld1 {v11.8h, v12.8h, v13.8h, v14.8h}, [x1], #64 + ld1 {v15.8h, v16.8h, v17.8h, v18.8h}, [x1], #64 + + fmla v0.8h, v3.8h, v2.h[0] + fmla v0.8h, v5.8h, v2.h[1] + fmla v0.8h, v7.8h, v2.h[2] + fmla v0.8h, v9.8h, v2.h[3] + fmla v0.8h, v11.8h, v2.h[4] + fmla v0.8h, v13.8h, v2.h[5] + fmla v0.8h, v15.8h, v2.h[6] + fmla v0.8h, v17.8h, v2.h[7] + fmla v1.8h, v4.8h, v2.h[0] + fmla v1.8h, v6.8h, v2.h[1] + fmla v1.8h, v8.8h, v2.h[2] + fmla v1.8h, v10.8h, v2.h[3] + fmla v1.8h, v12.8h, v2.h[4] + fmla v1.8h, v14.8h, v2.h[5] + fmla v1.8h, v16.8h, v2.h[6] + fmla v1.8h, v18.8h, v2.h[7] + + add w9, w9, #8 + b Loop2x8Inner + +DepthRemain: // last depth [0, 8) + cmp w18, #0 + ble Act + ld1 {v2.h}[0], [x15], #2 + ld1 {v3.8h}, [x1], #16 + ld1 {v4.8h}, [x1], #16 + fmla v0.8h, v3.8h, v2.h[0] + fmla v1.8h, v4.8h, v2.h[0] + sub w18, w18, #1 + b DepthRemain + +Act: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + movi v19.8h, #0x46, lsl #8 + fmin v0.8h, v0.8h, v19.8h + fmin v1.8h, v1.8h, v19.8h + +Relu: + dup v20.8h, wzr + fmax v0.8h, v0.8h, v20.8h + fmax v1.8h, v1.8h, v20.8h + +Write: + cmp w6, #8 + blt WriteMod8 + st1 {v0.8h}, [x2], #16 + sub w6, w6, #8 + mov v0.16b, v1.16b + cmp w6, #8 + blt WriteMod8 + st1 {v1.8h}, [x2], #16 + sub w6, w6, #8 + cbz w6, End + b LoopCol + +WriteMod8: + cmp w6, #0 + ble End + cmp w6, #1 + beq Write1 + cmp w6, #2 + beq Write2 + cmp w6, #3 + beq Write3 + cmp w6, #4 + beq Write4 + cmp w6, #5 + beq Write5 + cmp w6, #6 + beq Write6 + cmp w6, #7 + beq Write7 + +Write1: + st1 {v0.h}[0], [x2], #2 + b End +Write2: + st1 {v0.h}[0], [x2], #2 + st1 
{v0.h}[1], [x2], #2 + b End +Write3: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + b End +Write4: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + b End +Write5: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + b End +Write6: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + st1 {v0.h}[5], [x2], #2 + b End +Write7: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + st1 {v0.h}[5], [x2], #2 + st1 {v0.h}[6], [x2], #2 + b End + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..c308eb3415d338dbf5d970121e1c7fe6d67c5398 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S @@ -0,0 +1,150 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransLeftFp16 + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #8 // 4 * sizeof(float16) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #2 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4h, wzr + mov x11, x6 + InitZero: + st1 {v30.4h}, [x2], #8 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.h}[0], [x17], x10 + ld1 {v0.h}[1], [x17], x10 + ld1 {v0.h}[2], [x17], x10 + ld1 {v0.h}[3], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x14], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x20], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + ld1 {v21.4h}, [x19], #8 + fmla v17.4h, v21.4h, v0.h[3] + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.h}[0], [x17], x10 + ld1 {v0.h}[1], [x17], x10 + ld1 {v0.h}[2], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + LoopLength3: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x14], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x20], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4h}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4h}, [x2] + ld1 {v1.4h}, [x14], #8 + fmla v0.4h, v1.4h, v31.4h + st1 {v0.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #2 // sizeof(float16_t) + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..cde99cc120cf8d1898b79f0f31bc1cd4a1b6366a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S @@ -0,0 +1,154 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransRightFp16 + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #8 // 4 * sizeof(float16) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #2 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4h, wzr + mov x11, x6 + InitZero: + st1 {v30.4h}, [x2], #8 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.h}[0], [x13], x10 + ld1 {v0.h}[1], [x13], x10 + ld1 {v0.h}[2], [x13], x10 + ld1 {v0.h}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x19, x16, x8 + + LoopLength4: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x17], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x14], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + ld1 {v21.4h}, [x19], #8 + fmla v17.4h, v21.4h, v0.h[3] + + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x19 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.h}[0], [x13], x10 + ld1 {v0.h}[1], [x13], x10 + ld1 {v0.h}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x17], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x14], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x19 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4h}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4h}, [x2] + ld1 {v1.4h}, [x17], #8 + fmla v0.4h, v1.4h, v31.4h + + st1 {v0.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength + subs x12, x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #2 + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S new file mode 100644 index 0000000000000000000000000000000000000000..106eba3846523d9019f9ea27797c1cf34d1bde5f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S @@ -0,0 +1,764 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +// void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, +// float *bias, size_t row, size_t col, size_t stride, const int *a_sums, +// const int *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: deep4 +// x4: multi_scales +// x5: bias +// x6: row +// x7: col +// x8: stride +// x9: a_sums +// x10: b_sums +// x19/w19: a_zp +// x20/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 +// x22: mode -> 0: TensorXTensor, 1: TensorXChannel, 2: ChannelXTensor, else: ChannelXChannel + +asm_function DynamicMatmulSdot4x4x16AIWI + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x19, [sp, #24] + ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] + + dup v16.4s, wzr // dup: duplicate general-purpose register to vector. + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + mov x11, x1 // reload rhs ptr + mov x17, x0 // reload lhs ptr + mov x16, x3 // reload depth + + cmp x7, #4 + ble LoopDepthQuarter + cmp x7, #8 + ble LoopDepthHalf + +LoopDepth: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 + + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepth + b AddInputSum + +LoopDepthHalf: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthHalf + b AddInputSum + +LoopDepthQuarter: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthQuarter + b AddInputSum + +AddInputSum: + cmp w20, #0 + beq AddInputSumEnd + ld1 {v5.4s}, [x9], #16 + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + dup v8.4s, v5.s[2] + dup v9.4s, v5.s[3] + + sub v16.4s, v16.4s, v6.4s + sub v17.4s, v17.4s, v6.4s + sub v18.4s, v18.4s, v6.4s + sub v19.4s, v19.4s, v6.4s + sub v20.4s, v20.4s, v7.4s + sub v21.4s, v21.4s, v7.4s + sub v22.4s, v22.4s, v7.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v8.4s + sub v26.4s, v26.4s, v8.4s + sub v27.4s, v27.4s, v8.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v9.4s + sub v30.4s, v30.4s, v9.4s + sub v31.4s, v31.4s, v9.4s +AddInputSumEnd: + +AddWeightSum: + ld1 {v9.4s}, [x10], #16 + ld1 {v10.4s}, [x10], #16 + ld1 
{v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b AddBias + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b AddBias + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b AddBias + + ChannelXTensor: + ld1 {v0.4s}, [x4] + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul 
v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x5] + + fadd v16.4s,v16.4s,v1.4s + fadd v17.4s,v17.4s,v2.4s + fadd v18.4s,v18.4s,v3.4s + fadd v19.4s,v19.4s,v4.4s + + fadd v20.4s,v20.4s,v1.4s + fadd v21.4s,v21.4s,v2.4s + fadd v22.4s,v22.4s,v3.4s + fadd v23.4s,v23.4s,v4.4s + + fadd v24.4s,v24.4s,v1.4s + fadd v25.4s,v25.4s,v2.4s + fadd v26.4s,v26.4s,v3.4s + fadd v27.4s,v27.4s,v4.4s + + fadd v28.4s,v28.4s,v1.4s + fadd v29.4s,v29.4s,v2.4s + fadd v30.4s,v30.4s,v3.4s + fadd v31.4s,v31.4s,v4.4s + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4s, wzr + + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + b StoreData + +Relu6: + dup v1.4s, wzr + movi v2.4s, #6 + scvtf v2.4s, v2.4s + + // max (out, 0) + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + // min (out, 6) + + smin v16.4s,v16.4s,v2.4s + smin v17.4s,v17.4s,v2.4s + smin v18.4s,v18.4s,v2.4s + smin v19.4s,v19.4s,v2.4s + + smin v20.4s,v20.4s,v2.4s + smin v21.4s,v21.4s,v2.4s + smin v22.4s,v22.4s,v2.4s + smin v23.4s,v23.4s,v2.4s + + smin v24.4s,v24.4s,v2.4s + smin v25.4s,v25.4s,v2.4s + smin v26.4s,v26.4s,v2.4s + smin v27.4s,v27.4s,v2.4s + + smin v28.4s,v28.4s,v2.4s + smin v29.4s,v29.4s,v2.4s + smin v30.4s,v30.4s,v2.4s + smin v31.4s,v31.4s,v2.4s + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8 + st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2], x8 + st1 {v28.4s,v29.4s,v30.4s,v31.4s}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8 + st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2] + b 
StoreDataEnd + Write16Row2: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.1d}, [x15], #8 + st1 {v19.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.1d}, [x14], #8 + st1 {v23.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.1d}, [x13], #8 + st1 {v27.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.1d}, [x12], #8 + st1 {v31.s}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.1d}, [x12] + b StoreDataEnd + +Write13: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + b StoreDataEnd + +Write11: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.1d}, [x15], #8 + st1 {v18.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.1d}, [x14], #8 + st1 {v22.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.1d}, [x13], #8 + st1 {v26.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.1d}, [x12], #8 + st1 {v30.s}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.1d}, [x12] + b StoreDataEnd + +Write9: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4s,v17.4s}, [x15], #32 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + b StoreDataEnd + +Write7: + st1 {v16.4s}, [x15], #16 + st1 {v17.1d}, [x15], #8 + st1 {v17.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.1d}, [x14], #8 + st1 {v21.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.1d}, [x13], #8 + st1 {v25.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.1d}, 
[x12], #8 + st1 {v29.s}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4s}, [x15], #16 + st1 {v17.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.1d}, [x12] + b StoreDataEnd + +Write5: + st1 {v16.4s}, [x15], #16 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4s}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.1d}, [x15], #8 + st1 {v16.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.1d}, [x14], #8 + st1 {v20.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.1d}, [x13], #8 + st1 {v24.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.1d}, [x12], #8 + st1 {v28.s}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.1d}, [x12] + b StoreDataEnd + +Write1: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..b60055bddc85410fff6e653d12ec33cacbf04f3d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S @@ -0,0 +1,788 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +// void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, +// float16_t *multi_scales, float16_t *bias, size_t row, size_t col, size_t stride, +// const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, +// int64_t act_type, int64_t mode); +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: deep4 +// x4: multi_scales +// x5: bias +// x6: row +// x7: col +// x8: stride +// x9: a_sums +// x10: b_sums +// x19/w19: a_zp +// x20/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 +// x22: mode -> 0: TensorXTensor, 1: TensorXChannel, 2: ChannelXTensor, else: ChannelXChannel + +asm_function DynamicMatmulSdot4x4x16AIWIForFp16 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x19, [sp, #24] + ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] + + dup v16.4s, wzr // dup: duplicate general-purpose register to vector. + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + mov x11, x1 // reload rhs ptr + mov x17, x0 // reload lhs ptr + mov x16, x3 // reload depth + + cmp x7, #4 + ble LoopDepthQuarter + cmp x7, #8 + ble LoopDepthHalf + +LoopDepth: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 + + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepth + b AddInputSum + +LoopDepthHalf: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthHalf + b AddInputSum + +LoopDepthQuarter: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthQuarter + b AddInputSum + +AddInputSum: + cmp w20, #0 + beq AddInputSumEnd + ld1 {v5.4s}, [x9], #16 + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + dup v8.4s, v5.s[2] + dup v9.4s, v5.s[3] + + sub v16.4s, v16.4s, v6.4s + sub v17.4s, v17.4s, v6.4s + sub v18.4s, v18.4s, v6.4s + sub v19.4s, v19.4s, v6.4s + sub v20.4s, v20.4s, v7.4s + sub v21.4s, v21.4s, v7.4s + sub v22.4s, v22.4s, v7.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v8.4s + sub v26.4s, v26.4s, v8.4s + sub v27.4s, v27.4s, v8.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v9.4s + sub v30.4s, v30.4s, v9.4s + sub v31.4s, v31.4s, v9.4s +AddInputSumEnd: + +AddWeightSum: + ld1 {v9.4s}, [x10], 
#16 + ld1 {v10.4s}, [x10], #16 + ld1 {v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b ConvertHalfPrecision + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b ConvertHalfPrecision + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b ConvertHalfPrecision + + ChannelXTensor: + ld1 {v0.4s}, [x4] + 
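// ChannelXTensor: v0 holds one per-row scale; lane i of v0 scales all 16 output columns of row i +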
fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] + +ConvertHalfPrecision: +// from single-precision convert to half-precision + fcvtn v16.4h,v16.4s + fcvtn v17.4h,v17.4s + fcvtn v18.4h,v18.4s + fcvtn v19.4h,v19.4s + + fcvtn v20.4h,v20.4s + fcvtn v21.4h,v21.4s + fcvtn v22.4h,v22.4s + fcvtn v23.4h,v23.4s + + fcvtn v24.4h,v24.4s + fcvtn v25.4h,v25.4s + fcvtn v26.4h,v26.4s + fcvtn v27.4h,v27.4s + + fcvtn v28.4h,v28.4s + fcvtn v29.4h,v29.4s + fcvtn v30.4h,v30.4s + fcvtn v31.4h,v31.4s + +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4h, v2.4h, v3.4h, v4.4h}, [x5] + + fadd v16.4h,v16.4h,v1.4h + fadd v17.4h,v17.4h,v2.4h + fadd v18.4h,v18.4h,v3.4h + fadd v19.4h,v19.4h,v4.4h + + fadd v20.4h,v20.4h,v1.4h + fadd v21.4h,v21.4h,v2.4h + fadd v22.4h,v22.4h,v3.4h + fadd v23.4h,v23.4h,v4.4h + + fadd v24.4h,v24.4h,v1.4h + fadd v25.4h,v25.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v27.4h,v27.4h,v4.4h + + fadd v28.4h,v28.4h,v1.4h + fadd v29.4h,v29.4h,v2.4h + fadd v30.4h,v30.4h,v3.4h + fadd v31.4h,v31.4h,v4.4h + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4h, wzr + + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + b StoreData + +Relu6: + dup v1.4h, wzr + movi v2.4h, #6 + scvtf v2.4h, v2.4h + + // max (out, 0) + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + // min (out, 6) + + smin v16.4h,v16.4h,v2.4h + smin v17.4h,v17.4h,v2.4h + smin v18.4h,v18.4h,v2.4h + smin v19.4h,v19.4h,v2.4h + + smin v20.4h,v20.4h,v2.4h + smin v21.4h,v21.4h,v2.4h + smin v22.4h,v22.4h,v2.4h + smin v23.4h,v23.4h,v2.4h + + smin v24.4h,v24.4h,v2.4h + smin v25.4h,v25.4h,v2.4h + smin v26.4h,v26.4h,v2.4h + smin v27.4h,v27.4h,v2.4h + + smin v28.4h,v28.4h,v2.4h + smin v29.4h,v29.4h,v2.4h + smin v30.4h,v30.4h,v2.4h + smin v31.4h,v31.4h,v2.4h + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + 
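// full 16-column tile: x6 (remaining rows) selects how many 16-wide rows to store, stride x8 between rows +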
cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2], x8 + st1 {v28.4h,v29.4h,v30.4h,v31.4h}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2] + b StoreDataEnd + Write16Row2: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15], #4 + st1 {v19.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14], #4 + st1 {v23.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13], #4 + st1 {v27.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12], #4 + st1 {v31.h}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write13: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.h}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + b StoreDataEnd + +Write11: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15], #4 + st1 {v18.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14], #4 + st1 {v22.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13], #4 + st1 {v26.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12], #4 + st1 {v30.h}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write9: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.h}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4h,v17.4h}, [x15], #16 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + cmp x6, 
#2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + b StoreDataEnd + +Write7: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15], #4 + st1 {v17.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14], #4 + st1 {v21.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13], #4 + st1 {v25.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12], #4 + st1 {v29.h}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write5: + st1 {v16.4h}, [x15], #8 + st1 {v17.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.h}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4h}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.s}[0], [x15], #4 + st1 {v16.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14], #4 + st1 {v20.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13], #4 + st1 {v24.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12], #4 + st1 {v28.h}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd + +Write1: + st1 {v16.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.h}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S new file mode 100644 index 0000000000000000000000000000000000000000..c2818043b8d25c0968ba3f38fac2c5b8e9b4817c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S @@ -0,0 +1,864 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, +// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, +// const int *multiplier, const int *left_shift, const int *right_shift, int row, +// int col, int stride, int peroc); + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row8 +// w4: col8 +// w5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// w14: row +// w15: col +// w24: stride +// w27: filter_peroc + +asm_function MatmulInt8DpNeon64 + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr w14, [sp, #48] + ldr w15, [sp, #56] + ldr w24, [sp, #64] + ldr w27, [sp, #72] + + mov w17, #8 // sizeof(int8)*8 + mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4 + mov x25, x2 +L1: + cmp w4, #0 // if at the end of col8 + beq End1 + + mov w16, w3 // reset a row8 counter + mov w23, w14 // reset a row counter + mov x17, x0 // reload a ptr + mov x22, x6 // reload a_sums ptr +L2: + cmp w16, #0 + beq End2 + + mov x28, x1 // reload b ptr + mov x19, x7 // reload bias ptr + mov w20, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w20, #16 + blt LoopD4 + +LoopD16: + ld1 {v0.16b, v1.16b}, [x17], #32 + ld1 {v2.16b, v3.16b}, [x28], #32 + + sdot v16.4s, v2.16b, v0.4b[0] + sdot v18.4s, v2.16b, v0.4b[1] + sdot v20.4s, v2.16b, v0.4b[2] + sdot v22.4s, v2.16b, v0.4b[3] + + ld1 {v4.16b, v5.16b}, [x17], #32 + sdot v24.4s, v2.16b, v1.4b[0] + sdot v26.4s, v2.16b, v1.4b[1] + sdot v28.4s, v2.16b, v1.4b[2] + sdot v30.4s, v2.16b, v1.4b[3] + + ld1 {v6.16b, v7.16b}, [x28], #32 + sdot v17.4s, v3.16b, v0.4b[0] + sdot v19.4s, v3.16b, v0.4b[1] + sdot v21.4s, v3.16b, v0.4b[2] + sdot v23.4s, v3.16b, v0.4b[3] + + sdot v25.4s, v3.16b, v1.4b[0] + sdot v27.4s, v3.16b, v1.4b[1] + sdot v29.4s, v3.16b, v1.4b[2] + sdot v31.4s, v3.16b, v1.4b[3] + + ld1 {v8.16b, v9.16b}, [x17], #32 + sdot v16.4s, v6.16b, v4.4b[0] + sdot v18.4s, v6.16b, v4.4b[1] + sdot v20.4s, v6.16b, v4.4b[2] + sdot v22.4s, v6.16b, v4.4b[3] + + sdot v24.4s, v6.16b, v5.4b[0] + sdot v26.4s, v6.16b, v5.4b[1] + sdot v28.4s, v6.16b, v5.4b[2] + sdot v30.4s, v6.16b, v5.4b[3] + + ld1 {v10.16b, v11.16b}, [x28], #32 + sdot v17.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v4.4b[1] + sdot v21.4s, v7.16b, v4.4b[2] + sdot v23.4s, v7.16b, v4.4b[3] + + sdot v25.4s, v7.16b, v5.4b[0] + sdot v27.4s, v7.16b, v5.4b[1] + sdot v29.4s, v7.16b, v5.4b[2] + sdot v31.4s, v7.16b, v5.4b[3] + + ld1 {v12.16b, v13.16b}, [x17], #32 + sdot v16.4s, v10.16b, v8.4b[0] + sdot v18.4s, v10.16b, v8.4b[1] + sdot v20.4s, v10.16b, v8.4b[2] + sdot v22.4s, v10.16b, v8.4b[3] + + sdot v24.4s, v10.16b, v9.4b[0] + sdot v26.4s, v10.16b, v9.4b[1] + sdot v28.4s, v10.16b, v9.4b[2] + sdot v30.4s, v10.16b, v9.4b[3] + + ld1 {v14.16b, v15.16b}, [x28], #32 + sdot 
v17.4s, v11.16b, v8.4b[0] + sdot v19.4s, v11.16b, v8.4b[1] + sdot v21.4s, v11.16b, v8.4b[2] + sdot v23.4s, v11.16b, v8.4b[3] + + sdot v25.4s, v11.16b, v9.4b[0] + sdot v27.4s, v11.16b, v9.4b[1] + sdot v29.4s, v11.16b, v9.4b[2] + sdot v31.4s, v11.16b, v9.4b[3] + + sdot v16.4s, v14.16b, v12.4b[0] + sdot v18.4s, v14.16b, v12.4b[1] + sdot v20.4s, v14.16b, v12.4b[2] + sdot v22.4s, v14.16b, v12.4b[3] + + sdot v24.4s, v14.16b, v13.4b[0] + sdot v26.4s, v14.16b, v13.4b[1] + sdot v28.4s, v14.16b, v13.4b[2] + sdot v30.4s, v14.16b, v13.4b[3] + + sdot v17.4s, v15.16b, v12.4b[0] + sdot v19.4s, v15.16b, v12.4b[1] + sdot v21.4s, v15.16b, v12.4b[2] + sdot v23.4s, v15.16b, v12.4b[3] + + sdot v25.4s, v15.16b, v13.4b[0] + sdot v27.4s, v15.16b, v13.4b[1] + sdot v29.4s, v15.16b, v13.4b[2] + sdot v31.4s, v15.16b, v13.4b[3] + + subs w20, w20, #16 // depth - 16 + b L3 + +LoopD4: + cmp w20, #0 + beq End3 + + ld1 {v0.16b, v1.16b}, [x17], #32 + ld1 {v2.16b, v3.16b}, [x28], #32 + + sdot v16.4s, v2.16b, v0.4b[0] + sdot v18.4s, v2.16b, v0.4b[1] + sdot v20.4s, v2.16b, v0.4b[2] + sdot v22.4s, v2.16b, v0.4b[3] + sdot v24.4s, v2.16b, v1.4b[0] + sdot v26.4s, v2.16b, v1.4b[1] + sdot v28.4s, v2.16b, v1.4b[2] + sdot v30.4s, v2.16b, v1.4b[3] + sdot v17.4s, v3.16b, v0.4b[0] + sdot v19.4s, v3.16b, v0.4b[1] + sdot v21.4s, v3.16b, v0.4b[2] + sdot v23.4s, v3.16b, v0.4b[3] + sdot v25.4s, v3.16b, v1.4b[0] + sdot v27.4s, v3.16b, v1.4b[1] + sdot v29.4s, v3.16b, v1.4b[2] + sdot v31.4s, v3.16b, v1.4b[3] + + subs w20, w20, #4 // depth - 4 + b LoopD4 + +End3: + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x19], #16 + ld1 {v14.4s}, [x19], #16 + add v16.4s, v16.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v20.4s, v20.4s, v15.4s + add v22.4s, v22.4s, v15.4s + add v24.4s, v24.4s, v15.4s + add v26.4s, v26.4s, v15.4s + add v28.4s, v28.4s, v15.4s + add v30.4s, v30.4s, v15.4s + add v17.4s, v17.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v31.4s, v31.4s, v14.4s + + cmp w27, #0 + beq PerTSumLoad +PerCSumLoad: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + b ApplySum +PerTSumLoad: + ld1 {v14.4s}, [x22], #16 + ld1 {v15.4s}, [x22], #16 + dup v0.4s, v14.s[0] + dup v1.4s, v14.s[0] + dup v2.4s, v14.s[1] + dup v3.4s, v14.s[1] + dup v4.4s, v14.s[2] + dup v5.4s, v14.s[2] + dup v6.4s, v14.s[3] + dup v7.4s, v14.s[3] + dup v8.4s, v15.s[0] + dup v9.4s, v15.s[0] + dup v10.4s, v15.s[1] + dup v11.4s, v15.s[1] + dup v12.4s, v15.s[2] + dup v13.4s, v15.s[2] + dup v14.4s, v15.s[3] + dup v15.4s, v14.s[0] +ApplySum: + // Subtract (Asums*Zb) + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v9.4s + sub v26.4s, v26.4s, v10.4s + sub v27.4s, v27.4s, v11.4s + sub v28.4s, v28.4s, v12.4s + sub v29.4s, v29.4s, v13.4s + sub v30.4s, v30.4s, v14.4s + sub v31.4s, v31.4s, v15.4s + + cmp w27, #0 + beq PerTRoundLoad +PerCRoundLoad: + ld1 {v8.4s, v9.4s}, [x12] + ld1 {v10.4s, v11.4s}, [x11] + ld1 {v12.4s, v13.4s}, [x13] + b ApplyRound +PerTRoundLoad: + ld1 {v14.s}[0], [x12] + dup v8.4s, v14.s[0] + dup v9.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v10.4s, v14.s[0] + dup v11.4s, 
v14.s[0] + ld1 {v14.s}[0], [x13] + dup v12.4s, v14.s[0] + dup v13.4s, v14.s[0] +ApplyRound: + // Apply left shift + sqshl v16.4s, v16.4s, v8.4s + sqshl v17.4s, v17.4s, v9.4s + sqshl v18.4s, v18.4s, v8.4s + sqshl v19.4s, v19.4s, v9.4s + sqshl v20.4s, v20.4s, v8.4s + sqshl v21.4s, v21.4s, v9.4s + sqshl v22.4s, v22.4s, v8.4s + sqshl v23.4s, v23.4s, v9.4s + sqshl v24.4s, v24.4s, v8.4s + sqshl v25.4s, v25.4s, v9.4s + sqshl v26.4s, v26.4s, v8.4s + sqshl v27.4s, v27.4s, v9.4s + sqshl v28.4s, v28.4s, v8.4s + sqshl v29.4s, v29.4s, v9.4s + sqshl v30.4s, v30.4s, v8.4s + sqshl v31.4s, v31.4s, v9.4s + + // Apply the fixed-point part of the multiplier. + sqrdmulh v16.4s, v16.4s, v10.4s + sqrdmulh v17.4s, v17.4s, v11.4s + sqrdmulh v18.4s, v18.4s, v10.4s + sqrdmulh v19.4s, v19.4s, v11.4s + sqrdmulh v20.4s, v20.4s, v10.4s + sqrdmulh v21.4s, v21.4s, v11.4s + sqrdmulh v22.4s, v22.4s, v10.4s + sqrdmulh v23.4s, v23.4s, v11.4s + sqrdmulh v24.4s, v24.4s, v10.4s + sqrdmulh v25.4s, v25.4s, v11.4s + sqrdmulh v26.4s, v26.4s, v10.4s + sqrdmulh v27.4s, v27.4s, v11.4s + sqrdmulh v28.4s, v28.4s, v10.4s + sqrdmulh v29.4s, v29.4s, v11.4s + sqrdmulh v30.4s, v30.4s, v10.4s + sqrdmulh v31.4s, v31.4s, v11.4s + + // Apply right shift + and v0.16b, v12.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v12.4s + and v1.16b, v13.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v13.4s + and v2.16b, v12.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v12.4s + and v3.16b, v13.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v13.4s + and v0.16b, v12.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v12.4s + and v1.16b, v13.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v13.4s + and v2.16b, v12.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v12.4s + and v3.16b, v13.16b, v23.16b + sshr v3.4s, v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v13.4s + and v0.16b, v12.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v12.4s + and v1.16b, v13.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v13.4s + and v2.16b, v12.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v12.4s + and v3.16b, v13.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v13.4s + and v0.16b, v12.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v12.4s + and v1.16b, v13.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v13.4s + and v2.16b, v12.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v12.4s + and v3.16b, v13.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v13.4s + + // Add the destination zero point + dup v8.4s, w10 + add v16.4s, v16.4s, v8.4s + add v17.4s, v17.4s, v8.4s + add v18.4s, v18.4s, v8.4s + add v19.4s, v19.4s, v8.4s + add v20.4s, v20.4s, v8.4s + add v21.4s, v21.4s, v8.4s + add v22.4s, v22.4s, v8.4s + add v23.4s, v23.4s, v8.4s + add v24.4s, v24.4s, v8.4s + add v25.4s, v25.4s, v8.4s + add v26.4s, v26.4s, v8.4s + add v27.4s, v27.4s, v8.4s + add v28.4s, v28.4s, v8.4s + add v29.4s, v29.4s, v8.4s + add v30.4s, 
v30.4s, v8.4s + add v31.4s, v31.4s, v8.4s + + // Apply the act_min bound + dup v7.4s, w8 + smax v16.4s, v16.4s, v7.4s + smax v17.4s, v17.4s, v7.4s + smax v18.4s, v18.4s, v7.4s + smax v19.4s, v19.4s, v7.4s + smax v20.4s, v20.4s, v7.4s + smax v21.4s, v21.4s, v7.4s + smax v22.4s, v22.4s, v7.4s + smax v23.4s, v23.4s, v7.4s + smax v24.4s, v24.4s, v7.4s + smax v25.4s, v25.4s, v7.4s + smax v26.4s, v26.4s, v7.4s + smax v27.4s, v27.4s, v7.4s + smax v28.4s, v28.4s, v7.4s + smax v29.4s, v29.4s, v7.4s + smax v30.4s, v30.4s, v7.4s + smax v31.4s, v31.4s, v7.4s + + // Apply the act_max bound + dup v6.4s, w9 + smin v16.4s, v16.4s, v6.4s + smin v17.4s, v17.4s, v6.4s + smin v18.4s, v18.4s, v6.4s + smin v19.4s, v19.4s, v6.4s + smin v20.4s, v20.4s, v6.4s + smin v21.4s, v21.4s, v6.4s + smin v22.4s, v22.4s, v6.4s + smin v23.4s, v23.4s, v6.4s + smin v24.4s, v24.4s, v6.4s + smin v25.4s, v25.4s, v6.4s + smin v26.4s, v26.4s, v6.4s + smin v27.4s, v27.4s, v6.4s + smin v28.4s, v28.4s, v6.4s + smin v29.4s, v29.4s, v6.4s + smin v30.4s, v30.4s, v6.4s + smin v31.4s, v31.4s, v6.4s + + // int32 -> int16 + sqxtn v0.4h, v16.4s + sqxtn2 v0.8h, v17.4s + sqxtn v1.4h, v18.4s + sqxtn2 v1.8h, v19.4s + sqxtn v2.4h, v20.4s + sqxtn2 v2.8h, v21.4s + sqxtn v3.4h, v22.4s + sqxtn2 v3.8h, v23.4s + sqxtn v4.4h, v24.4s + sqxtn2 v4.8h, v25.4s + sqxtn v5.4h, v26.4s + sqxtn2 v5.8h, v27.4s + sqxtn v6.4h, v28.4s + sqxtn2 v6.8h, v29.4s + sqxtn v7.4h, v30.4s + sqxtn2 v7.8h, v31.4s + + // int16 -> int8 + sqxtn v8.8b, v0.8h + sqxtn2 v8.16b, v1.8h + sqxtn v9.8b, v2.8h + sqxtn2 v9.16b, v3.8h + sqxtn v10.8b, v4.8h + sqxtn2 v10.16b, v5.8h + sqxtn v11.8b, v6.8h + sqxtn2 v11.16b, v7.8h + + cmp w23, #8 + blt Write // if rows < 8 + cmp w15, #8 + blt Write // if cols < 8 + + st1 {v8.d}[0], [x2], x24 + st1 {v8.d}[1], [x2], x24 + st1 {v9.d}[0], [x2], x24 + st1 {v9.d}[1], [x2], x24 + st1 {v10.d}[0], [x2], x24 + st1 {v10.d}[1], [x2], x24 + st1 {v11.d}[0], [x2], x24 + st1 {v11.d}[1], [x2], x24 + b Endwrite + +Write: + cmp w15, #8 + bge WriteCol8 + cmp w15, #7 + beq WriteCol7 + cmp w15, #6 + beq WriteCol6 + cmp w15, #5 + beq WriteCol5 + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol8: + st1 {v8.d}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.d}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.d}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.d}[1], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.d}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.d}[1], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.d}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.d}[1], [x2], x24 + b Endwrite + +WriteCol7: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.h}[2], [x26], #2 + st1 {v8.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.h}[6], [x26], #2 + st1 {v8.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.h}[2], [x26], #2 + st1 {v9.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.h}[6], [x26], #2 + st1 {v9.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.h}[2], [x26], #2 + st1 {v10.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.h}[6], [x26], #2 + st1 {v10.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 
+ st1 {v11.s}[0], [x26], #4 + st1 {v11.h}[2], [x26], #2 + st1 {v11.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.h}[6], [x26], #2 + st1 {v11.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol6: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.h}[6], [x26], #2 + add x2, x2, x24 + b Endwrite + +WriteCol5: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.b}[12], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol4: + st1 {v8.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.s}[2], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.s}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.s}[2], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.s}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.s}[2], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.s}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.s}[2], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v8.h}[0], [x26], #2 + st1 {v8.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.h}[4], [x26], #2 + st1 {v8.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.h}[0], [x26], #2 + st1 {v9.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.h}[4], [x26], #2 + st1 {v9.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.h}[0], [x26], #2 + st1 {v10.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.h}[4], [x26], #2 + st1 {v10.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.h}[0], [x26], #2 + st1 {v11.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.h}[4], [x26], #2 + st1 {v11.b}[10], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + st1 {v8.h}[0], [x2], x24 + cmp w23, #1 + 
beq Endwrite + st1 {v8.h}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.h}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.h}[4], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.h}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.h}[4], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.h}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.h}[4], [x2], x24 + b Endwrite + +WriteCol1: + st1 {v8.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.b}[8], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.b}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.b}[8], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.b}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.b}[8], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.b}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.b}[8], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #8 // a row8 counter - 8 + sub w23, w23, #8 // a row counter - 8 + b L2 + +End2: + sub w4, w4, #8 // b col8 counter - 8 + sub w15, w15, #8 // b col counter - 8 + add x1, x1, x21 // b ptr + stride + add x7, x7, #32 // bias ptr + stride + add x25, x25, #8 // output + stride(8 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #32 + add x11, x11, #32 + add x13, x13, #32 +PerTEnd2: + b L1 + +End1: + sub sp, sp, #208 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..ee119b1a28eff65a6041ddfa9c1c25acff95f954 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S @@ -0,0 +1,1098 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
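Both int8 matmul kernels in this patch post-process the int32 accumulators the same way: the zero points are folded in through precomputed row/column sums, and the result is requantized with a saturating left shift (sqshl), a saturating rounding doubling high multiply (sqrdmulh), and a rounding right shift whose and/sshr/sqadd prologue corrects negative values before srshl. A scalar C sketch of one output lane follows; the helper names are illustrative (gemmlowp-style semantics), not the library's API, and sqshl's saturation is elided.

  #include <limits.h>
  #include <stdint.h>

  /* sum_d (a[d]-Za)*(b[d]-Zb) = sum_d a*b - Zb*sum_d a - Za*sum_d b + deep*Za*Zb.
   * The kernels preload bias + deep*Za*Zb - Za*Bsums and subtract Asums*Zb
   * afterwards, so the sdot loop can ignore the zero points entirely. */
  static int32_t CorrectAccumulator(int32_t acc, int32_t folded_bias, int32_t asum_times_zb) {
    return acc + folded_bias - asum_times_zb;
  }

  /* Scalar analogue of sqrdmulh: saturating, rounded (2 * a * b) >> 31. */
  static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
    if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX; /* the only overflow case */
    int64_t ab = (int64_t)a * (int64_t)b;
    int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
    return (int32_t)((ab + nudge) / (1LL << 31));
  }

  /* Rounding arithmetic right shift; the and/sshr/sqadd sequence in the
   * assembly is the vector form of this negative-remainder correction. */
  static int32_t RoundingDivideByPOT(int32_t x, int exponent) {
    int32_t mask = (int32_t)(((int64_t)1 << exponent) - 1);
    int32_t remainder = x & mask;
    int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
    return (x >> exponent) + (remainder > threshold ? 1 : 0);
  }

  static int8_t Requantize(int32_t acc, int left_shift, int32_t multiplier, int right_shift,
                           int32_t out_zp, int32_t act_min, int32_t act_max) {
    /* sqshl additionally saturates; a plain shift is shown for brevity. */
    int32_t v = SaturatingRoundingDoublingHighMul(acc * (1 << left_shift), multiplier);
    v = RoundingDivideByPOT(v, right_shift) + out_zp; /* destination zero point */
    v = v < act_min ? act_min : v;                    /* act_min bound */
    v = v > act_max ? act_max : v;                    /* act_max bound */
    return (int8_t)v;                                 /* int32 -> int16 -> int8 narrowing */
  }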
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +//void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep4, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier, +// const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, +// const int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row +// x4: col +// x5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +asm_function MatmulInt8DpOpt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + stp x29, x30, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr x14, [sp, #48] + ldr x15, [sp, #56] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step + +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x29, x7 // reload bias ptr + mov x25, x6 // reload input_sum ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #64] // reload filter_zp + + LoopCol: + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x17, #4 + ble LoopDepthQuarter + cmp x17, #8 + ble LoopDepthHalf + + LoopDepth: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x16], #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepth + + Bias: + cbz x7, NoReadBias + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x29], #64 + add v16.4s, v16.4s, v0.4s + add v17.4s, v17.4s, v1.4s + add v18.4s, v18.4s, v2.4s + add v19.4s, v19.4s, v3.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v22.4s, v22.4s, v2.4s + add v23.4s, v23.4s, v3.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v26.4s, v26.4s, v2.4s + add v27.4s, v27.4s, v3.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + add v30.4s, v30.4s, v2.4s + add v31.4s, v31.4s, v3.4s + + NoReadBias: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSum + + PerTensorSum: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v18.4s, v18.4s, v12.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v22.4s, v22.4s, v13.4s + sub v23.4s, v23.4s, v13.4s + sub v24.4s, 
v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v26.4s, v26.4s, v14.4s + sub v27.4s, v27.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + sub v30.4s, v30.4s, v15.4s + sub v31.4s, v31.4s, v15.4s + + b PerTensor + + PerChannelSum: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x28], #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v2.4s, v10.4s, v12.4s + mul v3.4s, v11.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + mul v6.4s, v10.4s, v13.4s + mul v7.4s, v11.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + mul v0.4s, v8.4s, v14.4s + mul v1.4s, v9.4s, v14.4s + mul v2.4s, v10.4s, v14.4s + mul v3.4s, v11.4s, v14.4s + mul v4.4s, v8.4s, v15.4s + mul v5.4s, v9.4s, v15.4s + mul v6.4s, v10.4s, v15.4s + mul v7.4s, v11.4s, v15.4s + sub v24.4s, v24.4s, v0.4s + sub v25.4s, v25.4s, v1.4s + sub v26.4s, v26.4s, v2.4s + sub v27.4s, v27.4s, v3.4s + sub v28.4s, v28.4s, v4.4s + sub v29.4s, v29.4s, v5.4s + sub v30.4s, v30.4s, v6.4s + sub v31.4s, v31.4s, v7.4s + + PerTensor: + cbnz x15, PerChannel + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + mov v6.16b, v4.16b + mov v7.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + mov v10.16b, v8.16b + mov v11.16b, v8.16b + + b Quantization + + PerChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x11], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x13], #64 + + Quantization: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v18.4s, v18.4s, v2.4s + sqshl v19.4s, v19.4s, v3.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v22.4s, v22.4s, v2.4s + sqshl v23.4s, v23.4s, v3.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v26.4s, v26.4s, v2.4s + sqshl v27.4s, v27.4s, v3.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + sqshl v30.4s, v30.4s, v2.4s + sqshl v31.4s, v31.4s, v3.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v18.4s, v18.4s, v6.4s + sqrdmulh v19.4s, v19.4s, v7.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v22.4s, v22.4s, v6.4s + sqrdmulh v23.4s, v23.4s, v7.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v26.4s, v26.4s, v6.4s + sqrdmulh v27.4s, v27.4s, v7.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + sqrdmulh v30.4s, v30.4s, v6.4s + sqrdmulh v31.4s, v31.4s, v7.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + and v2.16b, v10.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v10.4s + and v3.16b, v11.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v11.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + and v2.16b, v10.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v10.4s + and v3.16b, v11.16b, v23.16b + sshr v3.4s, 
v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v11.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + and v2.16b, v10.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v10.4s + and v3.16b, v11.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v11.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + and v2.16b, v10.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v10.4s + and v3.16b, v11.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v11.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v18.4s, v18.4s, v6.4s + add v19.4s, v19.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v22.4s, v22.4s, v6.4s + add v23.4s, v23.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v26.4s, v26.4s, v6.4s + add v27.4s, v27.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + add v30.4s, v30.4s, v6.4s + add v31.4s, v31.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v18.4s, v18.4s, v0.4s + smax v19.4s, v19.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v22.4s, v22.4s, v0.4s + smax v23.4s, v23.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v26.4s, v26.4s, v0.4s + smax v27.4s, v27.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + smax v30.4s, v30.4s, v0.4s + smax v31.4s, v31.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + smin v18.4s, v18.4s, v1.4s + smin v19.4s, v19.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v22.4s, v22.4s, v1.4s + smin v23.4s, v23.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v26.4s, v26.4s, v1.4s + smin v27.4s, v27.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + smin v30.4s, v30.4s, v1.4s + smin v31.4s, v31.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + sqxtn v18.4h, v18.4s + sqxtn2 v18.8h, v19.4s + sqxtn2 v0.16b, v18.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + sqxtn v22.4h, v22.4s + sqxtn2 v22.8h, v23.4s + sqxtn2 v1.16b, v22.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn2 v2.16b, v26.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + sqxtn v30.4h, v30.4s + sqxtn2 v30.8h, v31.4s + sqxtn2 v3.16b, v30.8h + + b WriteStart + + LoopDepthHalf: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthHalf + + BiasHalf: + cbz x7, NoReadBiasHalf + ld1 {v0.4s, v1.4s}, [x29] + add x29, x29, #64 + add v16.4s, v16.4s, v0.4s 
+ add v17.4s, v17.4s, v1.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + + NoReadBiasHalf: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumHalf + + PerTensorSumHalf: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + + b PerTensorHalf + + PerChannelSumHalf: + ld1 {v8.4s, v9.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + mul v2.4s, v8.4s, v14.4s + mul v3.4s, v9.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + mul v7.4s, v9.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v25.4s, v25.4s, v3.4s + sub v28.4s, v28.4s, v6.4s + sub v29.4s, v29.4s, v7.4s + + PerTensorHalf: + cbnz x15, PerChannelHalf + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + + b QuantizationHalf + + PerChannelHalf: + ld1 {v0.4s, v1.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s, v5.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s, v9.4s}, [x13] + add x13, x13, #64 + + QuantizationHalf: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + 
smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + LoopDepthQuarter: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthQuarter + + BiasQuarter: + cbz x7, NoReadBiasQuarter + ld1 {v0.4s}, [x29] + add x29, x29, #64 + add v16.4s, v16.4s, v0.4s + add v20.4s, v20.4s, v0.4s + add v24.4s, v24.4s, v0.4s + add v28.4s, v28.4s, v0.4s + + NoReadBiasQuarter: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumQuarter + + PerTensorSumQuarter: + sub v16.4s, v16.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + + b PerTensorQuarter + + PerChannelSumQuarter: + ld1 {v8.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v20.4s, v20.4s, v4.4s + mul v2.4s, v8.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v28.4s, v28.4s, v6.4s + + PerTensorQuarter: + cbnz x15, PerChannelQuarter + ld1r {v0.4s}, [x12] + ld1r {v4.4s}, [x11] + ld1r {v8.4s}, [x13] + + b QuantizationHalf + + PerChannelQuarter: + ld1 {v0.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s}, [x13] + add x13, x13, #64 + + QuantizationQuarter: + sqshl v16.4s, v16.4s, v0.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v28.4s, v28.4s, v0.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v28.4s, v28.4s, v4.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v28.4s, v28.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + WriteStart: + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + cmp x17, #4 + beq Write4 + cmp x17, #5 + beq Write5 + cmp x17, #6 + beq Write6 + cmp x17, #7 + beq Write7 + cmp x17, #8 + beq Write8 + cmp x17, #9 + beq Write9 + cmp x17, #10 + beq Write10 + cmp x17, #11 + beq Write11 + cmp x17, #12 + beq 
Write12 + cmp x17, #13 + beq Write13 + cmp x17, #14 + beq Write14 + cmp x17, #15 + beq Write15 + b Write16 + + Write1: + add x27, x27, #1 + st1 {v0.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.b}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.b}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.b}[0], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v0.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v0.h}[0], [x19], x14 + st1 {v0.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + st1 {v1.b}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + st1 {v2.b}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + st1 {v3.b}[2], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v0.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + b WriteEnd + Write5: + add x27, x27, #5 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.b}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.b}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.b}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.b}[4], [x22], x14 + b WriteEnd + Write6: + add x27, x27, #6 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + b WriteEnd + Write7: + add x27, x27, #7 + add x22, x19, #4 + add x26, x19, #6 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + st1 {v0.b}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + st1 {v1.b}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + st1 {v2.b}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + st1 {v3.b}[6], [x26], x14 + b WriteEnd + Write8: + add x27, x27, #8 + st1 {v0.8b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + b WriteEnd + Write9: + add x27, x27, #9 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.b}[8], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.b}[8], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.b}[8], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.b}[8], [x22], x14 + b WriteEnd + Write10: + add x27, x27, #10 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + b WriteEnd + Write11: + add x27, x27, #11 + add x22, x19, #8 + add x26, x19, #10 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + st1 {v0.b}[10], [x26], x14 + cmp x3, #1 + 
beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + st1 {v1.b}[10], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + st1 {v2.b}[10], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + st1 {v3.b}[10], [x26], x14 + b WriteEnd + Write12: + add x27, x27, #12 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + b WriteEnd + Write13: + add x27, x27, #13 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.b}[12], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.b}[12], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.b}[12], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.b}[12], [x26], x14 + b WriteEnd + Write14: + add x27, x27, #14 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + b WriteEnd + Write15: + add x27, x27, #15 + add x22, x19, #8 + add x26, x19, #12 + add x21, x19, #14 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + st1 {v0.b}[14], [x21], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + st1 {v1.b}[14], [x21], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + st1 {v2.b}[14], [x21], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + st1 {v3.b}[14], [x21], x14 + b WriteEnd + Write16: + add x27, x27, #16 + st1 {v0.16b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.16b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.16b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.16b}, [x19], x14 + + WriteEnd: + subs x17, x17, #16 + ble LoopColEnd + mov x25, x6 + b LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + sub sp, sp, #224 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S new file mode 100644 index 0000000000000000000000000000000000000000..28db29cb5422d24fddbf4432793aa69ac5db3021 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S @@ -0,0 +1,155 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +//void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16, +// const int *input_sum, const int *bias) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias + +asm_function MatMulOptR4Int8Neon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + mov w15, #0 // b col index + mov w16, #0 // a row index + mov w17, #4 // sizeof(int8)*4 + mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + +L1: + cmp w15, w4 + beq End1 + + mov w16, #0 // reset a row index + mov x17, x0 // reload a ptr + mov x13, x6 // reload a_sums ptr +L2: + cmp w16, w3 + beq End2 + + mov x19, x1 // reload b ptr + mov x10, x7 // reload bias ptr + mov w11, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w11, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x19], #16 + ld1 {v5.16b}, [x19], #16 + ld1 {v6.16b}, [x19], #16 + ld1 {v7.16b}, [x19], #16 + + sdot v16.4s, v4.16b, v0.16b + sdot v17.4s, v5.16b, v0.16b + sdot v18.4s, v6.16b, v0.16b + sdot v19.4s, v7.16b, v0.16b + sdot v20.4s, v4.16b, v1.16b + sdot v21.4s, v5.16b, v1.16b + sdot v22.4s, v6.16b, v1.16b + sdot v23.4s, v7.16b, v1.16b + sdot v24.4s, v4.16b, v2.16b + sdot v25.4s, v5.16b, v2.16b + sdot v26.4s, v6.16b, v2.16b + sdot v27.4s, v7.16b, v2.16b + sdot v28.4s, v4.16b, v3.16b + sdot v29.4s, v5.16b, v3.16b + sdot v30.4s, v6.16b, v3.16b + sdot v31.4s, v7.16b, v3.16b + subs w11, w11, #16 // depth - 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x10], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + // Subtract (Asums*Zb) + ld1 {v14.4s}, [x13], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + add w16, w16, #4 // a row index + 4 + b L2 + +End2: + add w15, w15, #4 // b col index + 4 + add
x1, x1, x12 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + b L1 + +End1: + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h new file mode 100644 index 0000000000000000000000000000000000000000..d1f5ca8bd6024efc1beaf2c70fc9a6f24ff24a1c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ASSEMBLY_GLOBAL_H +#define NNACL_ASSEMBLY_GLOBAL_H + +// clang-format off +.macro asm_function fname +#ifdef __APPLE__ +.globl _\fname +_\fname: +#else +.global \fname +#ifdef __ELF__ +.hidden \fname +.type \fname, %function +#endif +\fname: +#endif +.endm + +// clang-format off +.macro asm_default_function fname +#ifdef __APPLE__ +.globl _\fname +_\fname: +#else +.global \fname +#ifdef __ELF__ +.type \fname, %function +#endif +\fname: +#endif +.endm + +// clang-format on + +#endif // NNACL_ASSEMBLY_GLOBAL_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..f02a87fa258c145be30bb051be8141ae9824edfc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
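The asm_function macro above only emits the platform-appropriate global symbol (hidden and typed on ELF targets) followed by the label, so a routine defined with it is reached from C through an ordinary prototype. A minimal usage sketch for the kernel above; the wrapper is hypothetical, and the tile sizes must already be padded as the parameter names indicate:

  #include <stdint.h>

  /* Prototype mirroring the register map documented in MatmulOptR4Int8.S. */
  void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4,
                             int deep16, const int *input_sum, const int *bias);

  static void RunOneTile(const int8_t *a, const int8_t *b, int *dst, const int *input_sum,
                         const int *bias) {
    /* row4/col4/deep16 are the row, column and depth counts padded to 4/4/16. */
    MatMulOptR4Int8Neon64(a, b, dst, 4, 4, 16, input_sum, bias);
  }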
+ */ +#ifndef NNACL_ATTENTION_PARAMETER_H_ +#define NNACL_ATTENTION_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct AttentionParameter { + OpParameter op_parameter_; + int head_num_; + int head_size_; + bool cross_; +} AttentionParameter; + +typedef struct RelativePositionAttentionParameter { + // Primitive parameter + OpParameter op_parameter_; + // multi-head-attention args + int num_heads_; // number of heads of multi-head-attention + int k_seq_; // length of sequence of key of attention + int v_seq_; // length of sequence of value of attention + bool use_bias_; // if matmul in attention has bias + // relative-position-attention args + int p_seq_; // length of sequence of position of attention + // args for compute + int batch_; // batch of query/key/value/position + int d_model_; // d_model of multi-head-attention + int q_seq_; // length of sequence of query of attention + int row_tile_; // row tile for matrix pack + int col_tile_; // col tile for matrix pack + int bias_tile_; // tile for bias pack +} RelativePositionAttentionParameter; + +#endif // NNACL_ATTENTION_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c new file mode 100644 index 0000000000000000000000000000000000000000..793aceeff915711cc415121c5f2323b614d417ef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
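As a usage sketch, a caller fills RelativePositionAttentionParameter before dispatching the kernel; every shape value below is hypothetical, and the tile fields are left for the packing stage:

  RelativePositionAttentionParameter param = {0};
  param.num_heads_ = 8;   /* heads of multi-head-attention */
  param.batch_ = 1;       /* batch of query/key/value/position */
  param.d_model_ = 512;
  param.q_seq_ = 64;      /* query sequence length */
  param.k_seq_ = 64;      /* key sequence length */
  param.v_seq_ = 64;      /* value sequence length */
  param.p_seq_ = 128;     /* relative-position sequence length */
  param.use_bias_ = true;
  /* row_tile_ / col_tile_ / bias_tile_ come from the matrix pack configuration. */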
+ */ + +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/kernel/arithmetic.h" + +void CalcMultiplesAndStrides(ArithmeticParameter *param) { + for (size_t i = 0; i < param->ndim_; i++) { + if (param->in_shape0_[i] != 0) { + param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; + } + if (param->in_shape1_[i] != 0) { + param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + } + } + // cal strides + ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_); + ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_); + ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_); +} + +void CalcStructMultiplesAndStrides(ArithmeticStruct *arithmetic) { + for (size_t i = 0; i < arithmetic->ndim_; i++) { + if (arithmetic->in_shape0_[i] != 0) { + arithmetic->multiples0_[i] = arithmetic->out_shape_[i] / arithmetic->in_shape0_[i]; + } + if (arithmetic->in_shape1_[i] != 0) { + arithmetic->multiples1_[i] = arithmetic->out_shape_[i] / arithmetic->in_shape1_[i]; + } + } + // cal strides + ComputeStrides(arithmetic->in_shape0_, arithmetic->in_strides0_, arithmetic->ndim_); + ComputeStrides(arithmetic->in_shape1_, arithmetic->in_strides1_, arithmetic->ndim_); + ComputeStrides(arithmetic->out_shape_, arithmetic->out_strides_, arithmetic->ndim_); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h new file mode 100644 index 0000000000000000000000000000000000000000..af095319fe9a9409c60dba2edc4f5f78bb476ad4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_ARITHMETIC_BASE_H_ +#define NNACL_BASE_ARITHMETIC_BASE_H_ + +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/nnacl_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/kernel/arithmetic.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void CalcMultiplesAndStrides(ArithmeticParameter *param); +void CalcStructMultiplesAndStrides(ArithmeticStruct *arithmetic); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_BASE_ARITHMETIC_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c new file mode 100644 index 0000000000000000000000000000000000000000..6e08f6f982d60e20f3b53c663acaa56cd15fc2a1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/batch_to_space_base.h" + +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int stride_h = block_w * out_n; + int output_offset = 0; + int copy_size = in_c * data_size; + int in_stride_h = in_w * in_c; + int in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + int h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + int w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy((int8_t *)output + output_offset, (int8_t *)input + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} + +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size) { + int block_h = block[0]; + int block_w = block[1]; + if (block_h == 0 || block_w == 0) { + return; + } + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + int stride_h = block_w * out_n; + int output_offset = 0; + int copy_size = in_c * data_size; + int in_stride_h = in_w * in_c; + int in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + int h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + int h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + int w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + int in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy((int8_t *)output + output_offset, (int8_t *)input + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h new file mode 100644 index 0000000000000000000000000000000000000000..c85dd38005f8a06fb84bc70605cf3d6060eefd63 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BATCH_TO_SPACE_BASE_H_ +#define NNACL_BATCH_TO_SPACE_BASE_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size); +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BATCH_TO_SPACE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c new file mode 100644 index 0000000000000000000000000000000000000000..8853abf0aa3eae6b1ad99e3f9cc53fccbcf47edb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c @@ -0,0 +1,106 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
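The in_offset arithmetic in both BatchToSpace variants decodes which input batch feeds each output pixel: with stride_h = block_w * out_n, the index bh * stride_h + bw * out_n + n equals (bh * block_w + bw) * out_n + n, so the intra-block offsets vary slowest. A small illustrative helper (not part of the patch) making the mapping explicit:

  /* Output batch n at intra-block offset (bh, bw) reads input batch
   * (bh * block_w + bw) * out_n + n at spatial position (h, w); e.g. for
   * block = {2, 2} and out_n = 1, output pixel (2h + bh, 2w + bw) comes
   * from input batch bh * 2 + bw. */
  static int BatchToSpaceInputBatch(int n, int bh, int bw, int block_w, int out_n) {
    return (bh * block_w + bw) * out_n + n;
  }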
+ */ + +#include "nnacl_c/base/broadcast_to.h" +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +size_t accumulate(const int *shape, int start, int end) { + size_t product = 1; + for (int i = start; i <= end; ++i) { + product *= (size_t)shape[i]; + } + return product; +} + +void pad_input_shape(int *input_shape, int input_shape_len, int output_shape_len) { + if (input_shape_len < output_shape_len) { + const int shape_gap = output_shape_len - input_shape_len; + for (int i = input_shape_len - 1; i >= 0; --i) { + input_shape[i + shape_gap] = input_shape[i]; + } + for (int i = 0; i < shape_gap; ++i) { + input_shape[i] = 1; + } + } +} + +#define BROADCAST_TO_SIZE_IMPL(data_size) \ + int BroadcastToSize##data_size(const void *input, BroadcastShapeInfo *shape_info, void *output) { \ + if (input == NULL || output == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + if (shape_info->output_shape_size_ > MAX_SHAPE_SIZE) { \ + return NNACL_ERR; \ + } \ + int *input_shape = shape_info->input_shape_; \ + const int *output_shape = shape_info->output_shape_; \ + const int dim_max = shape_info->output_shape_size_ - 1; \ + const size_t temp_length = accumulate(output_shape, 0, dim_max); \ + const size_t data_len = data_size / BYTE_SIZE; \ + if (temp_length * data_len == 0) { \ + return NNACL_ERR; \ + } \ + int8_t *data_temp = (int8_t *)malloc(temp_length * data_len); \ + if (data_temp == NULL) { \ + return NNACL_ERR; \ + } \ + pad_input_shape(input_shape, shape_info->input_shape_size_, dim_max + 1); \ + shape_info->input_shape_size_ = dim_max + 1; \ + \ + size_t before_dim_elements_num = accumulate(input_shape, 0, dim_max - 1); \ + size_t after_dim_elements_num = (size_t)(input_shape[dim_max]); \ + size_t dim_broadcast_rate = (size_t)(output_shape[dim_max] / input_shape[dim_max]); \ + for (size_t i = 0; i < before_dim_elements_num; ++i) { \ + const int8_t *in_ptr = (const int8_t *)input + i * after_dim_elements_num * data_len; \ + for (size_t j = 0; j < dim_broadcast_rate; ++j) { \ + int8_t *out_ptr = (int8_t *)output + (i * dim_broadcast_rate + j) * after_dim_elements_num * data_len; \ + memcpy(out_ptr, in_ptr, after_dim_elements_num *data_len); \ + } \ + } \ + \ + int dim_index = dim_max - 1; \ + while (dim_index >= 0) { \ + if (input_shape[dim_index] == 0) { \ + free(data_temp); \ + return NNACL_ERR; \ + } \ + dim_broadcast_rate = (size_t)(output_shape[dim_index] / input_shape[dim_index]); \ + if (dim_broadcast_rate > 1) { \ + before_dim_elements_num = accumulate(input_shape, 0, dim_index - 1); \ + after_dim_elements_num = accumulate(output_shape, dim_index + 1, dim_max); \ + for (size_t i = 0; i < before_dim_elements_num; ++i) { \ + int8_t *in_ptr = (int8_t *)output + i * after_dim_elements_num * data_len; \ + for (size_t j = 0; j < dim_broadcast_rate; ++j) { \ + int8_t *out_ptr = data_temp + (i * dim_broadcast_rate + j) * after_dim_elements_num * data_len; \ + memcpy(out_ptr, in_ptr, after_dim_elements_num *data_len); \ + } \ + } \ + size_t elements_total = before_dim_elements_num * dim_broadcast_rate * after_dim_elements_num; \ + memcpy(output, data_temp, elements_total *data_len); \ + } \ + --dim_index; \ + } \ + free(data_temp); \ + return NNACL_OK; \ + } + +BROADCAST_TO_SIZE_IMPL(8) +BROADCAST_TO_SIZE_IMPL(16) +BROADCAST_TO_SIZE_IMPL(32) +BROADCAST_TO_SIZE_IMPL(64) +BROADCAST_TO_SIZE_IMPL(128) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h new file mode 100644 index 
0000000000000000000000000000000000000000..d13114b0677eb0c2ac9e30ccf1f73173d8ba693f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_BROADCAST_TO_H_ +#define NNACL_FP32_BROADCAST_TO_H_ + +#include "nnacl_c/broadcast_to_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define BYTE_SIZE 8 +int BroadcastToSize8(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize16(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize32(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize64(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize128(const void *input, BroadcastShapeInfo *shape_info, void *output); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_BROADCAST_TO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c new file mode 100644 index 0000000000000000000000000000000000000000..6d0abc460bee3ecf0d9e6959f516e7220b9e6817 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c @@ -0,0 +1,199 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
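The BroadcastToSizeN entry points declared above are selected by element width in bits (BYTE_SIZE converts the suffix to bytes), so float32 tensors go through BroadcastToSize32. A usage sketch with hypothetical shapes, relying only on the BroadcastShapeInfo fields the implementation reads:

  #include <string.h>
  #include "nnacl_c/base/broadcast_to.h"

  static int BroadcastExample(void) {
    BroadcastShapeInfo info = {0};
    int in[2] = {1, 3}, out[2] = {2, 3};
    memcpy(info.input_shape_, in, sizeof(in));
    info.input_shape_size_ = 2;
    memcpy(info.output_shape_, out, sizeof(out));
    info.output_shape_size_ = 2;
    float src[3] = {1.0f, 2.0f, 3.0f};
    float dst[6];
    /* Tiles the single input row twice: dst = {1, 2, 3, 1, 2, 3}. */
    return BroadcastToSize32(src, &info, dst); /* NNACL_OK on success */
  }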
+ */ + +#include "nnacl_c/base/cast_base.h" +#include "nnacl_c/cast_base_simd.h" + +typedef union float32_bits { + unsigned int u; + float f; +} float32_bits; + +uint16_t Float32ToFloat16_(float f) { + float32_bits hbit; + hbit.f = f; + uint16_t hbits = 0; + // Extract the sign bit + uint16_t sign = (hbit.u >> FP16_BIT_SIZE) & 0x8000; // Get the sign (1 bit) ox8000 + // Extract the exponent + uint32_t exponent = (hbit.u >> FP32_SIGNIFICAND) & 0xFF; // Extract the exponent (8 bits) 0xFF + // Handle special cases (NaN, Inf, 0) + if (exponent == 0xFF) { // NaN or Infinity 0xFF + hbits |= sign | 0x7FFF; // Set to max float16 value (Infinity) + return hbits; + } else if (exponent == 0) { // Zero or denormalized number + // In float16, we treat zero the same way + hbits |= sign; // Preserve sign for zero + return hbits; + } + // Adjust the exponent to fit float16 + exponent -= FP32_EXPONENT_BIAS; // Remove float32 bias + exponent += FP16_EXPONENT_BIAS; // Add float16 bias + // Check for overflow + if (exponent >= 0x1F) { // 0X1F + hbits |= sign | 0x7FFF; // Set to max float16 value (Infinity) 0x7FFF + return hbits; + } + if (exponent == 0) { + // Handle underflow (too small to represent) + return sign; // Return zero with the correct sign + } + // Shift the mantissa: + // Extract the mantissa (23 bits), shift right by 13 (10-exp) + uint32_t mantissa = (hbit.u & 0x7FFFFF) >> FP16_SHIFT; // 0x7FFFFF + // Combine sign, exponent, and mantissa into hbits + hbits |= + sign | ((uint16_t)exponent << FP16_SIGNIFICAND) | (mantissa & 0x3FF); // combine sign exponent and mantissa 0x3FF + return hbits; +} + +void Int32ToFloat32(const int32_t *input, float *output, int number) { + int index = 0; + + SIMD_RUN_NO_SCALAR(Int32ToFloat32, index, input, output, number); + + for (; index < number; ++index) { + output[index] = (float)input[index]; + } +} + +void Float32ToInt32(const float *input, int32_t *output, int number) { + int index = 0; + + SIMD_RUN_X86_NO_SCALAR(Float32ToInt32, index, input, output, number); + + for (; index < number; ++index) { + output[index] = (int32_t)input[index]; + } +} + +void BoolToFloat32(const bool *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Uint8ToFloat32(const uint8_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Int32ToFloat32(const int32_t *input, float *output, int number); + +void Int64ToFloat32(const int64_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +#ifdef ENABLE_FP16 +void Int64ToFp16(const int64_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Int32ToFp16(const int32_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void BoolToFp16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Uint8ToFp16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Float32ToFp16(const float *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)(input[i]); + } +} + +void Fp16ToFloat32(const float16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = 
(float)(input[i]); + } +} +#else +void Fp16ToFloat32(const uint16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ShortToFloat32(input[i]); + } +} + +void Float32ToFp16(const float *input, uint16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = Float32ToFloat16_(input[i]); + } +} +#endif + +void Float32ToInt32(const float *input, int32_t *output, int number); + +void Float32ToInt64(const float *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int32ToInt64(const int32_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int64ToInt32(const int64_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +void Float32ToInt16(const float *input, int16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int16_t)input[i]; + } +} + +void BoolToInt32(const bool *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + if (input[i]) { + output[i] = 1; + } else { + output[i] = 0; + } + } +} + +void Float32ToBool(const float *input, bool *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (bool)input[i]; + } +} + +void Float32ToUint8(const float *input, uint8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (uint8_t)input[i]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h new file mode 100644 index 0000000000000000000000000000000000000000..52db13c84278376aae403b7b84692ab896557d35 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
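To make the scalar half-precision helper concrete, a worked conversion (the values follow from the IEEE-754 layouts; note the mantissa is truncated, not rounded):

/* 1.5f: sign 0, biased exponent 127, mantissa 0x400000.
 * float16 exponent: 127 - 127 + 15 = 15; top 10 mantissa bits: 0x400000 >> 13 = 0x200.
 * result: (15 << 10) | 0x200 = 0x3E00, the float16 encoding of 1.5. */
uint16_t h = Float32ToFloat16_(1.5f);  /* h == 0x3E00 */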
+ */ +#ifndef NNACL_BASE_CAST_BASE_H_ +#define NNACL_BASE_CAST_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BoolToFloat32(const bool *input, float *output, int number); + +void Uint8ToFloat32(const uint8_t *input, float *output, int number); + +void Int32ToFloat32(const int32_t *input, float *output, int number); + +void Int64ToFloat32(const int64_t *input, float *output, int number); + +#ifdef ENABLE_FP16 +void Int64ToFp16(const int64_t *input, float16_t *output, int number); + +void Int32ToFp16(const int32_t *input, float16_t *output, int number); + +void BoolToFp16(const bool *input, float16_t *output, int number); + +void Uint8ToFp16(const uint8_t *input, float16_t *output, int number); + +void Float32ToFp16(const float *input, float16_t *output, int number); + +void Fp16ToFloat32(const float16_t *input, float *output, int number); +#else +void Fp16ToFloat32(const uint16_t *input, float *output, int number); + +void Float32ToFp16(const float *input, uint16_t *output, int number); +#endif + +uint16_t Float32ToFloat16_(float f); + +void Float32ToInt32(const float *input, int32_t *output, int number); + +void Float32ToInt64(const float *input, int64_t *output, int number); + +void Int32ToInt64(const int32_t *input, int64_t *output, int number); + +void Int64ToInt32(const int64_t *input, int32_t *output, int number); + +void Float32ToInt16(const float *input, int16_t *output, int number); + +void BoolToInt32(const bool *input, int32_t *output, int number); + +void Float32ToBool(const float *input, bool *output, int number); + +void Float32ToUint8(const float *input, uint8_t *output, int number); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CAST_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..b7ad2f3206592b5045c8365aec636c097f32041c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_CAST_BASE_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_CAST_BASE_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int Int32ToFloat32@SIMD_INSTRUCTION@(int index, const int32_t *input, float *output, int number) { + for (int block_max_size = number - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 value = SIMD_LD_EPI32(input + index); + SIMD_ST_F32(output + index, SIMD_EPI32_TO_F32(value)); + } + return index; +} + +#ifndef MS_SIMD_NEON +static inline int Float32ToInt32@SIMD_INSTRUCTION@(int index, const float *input, int32_t *output, int number) { + for (int block_max_size = number - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(input + index); + SIMD_ST_EPI32(output + index, SIMD_F32_TO_EPI32(value)); + } + return index; +} +#endif + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c new file mode 100644 index 0000000000000000000000000000000000000000..b40f4473ba63d187d80aa75aaee55f5ac0bbaf6a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
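For illustration, the expansion the build generates from the first template routine for one instruction set: CMake substitutes @SIMD_INSTRUCTION@ per target, and the SIMD_* aliases resolve to that target's intrinsic wrappers (BLOCK_NUM of 8 fp32 lanes for AVX is an assumption based on the 256-bit register width):

static inline int Int32ToFloat32AVX(int index, const int32_t *input, float *output, int number) {
  for (int block_max_size = number - 8 + 1; index < block_max_size; index += 8) {
    SIMD_EPI32 value = SIMD_LD_EPI32(input + index);       /* load 8 int32 */
    SIMD_ST_F32(output + index, SIMD_EPI32_TO_F32(value)); /* convert and store 8 fp32 */
  }
  return index; /* the caller's scalar loop finishes the tail */
}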
+ */ + +#include "nnacl_c/base/concat_base.h" + +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, + int task_id, int thread_num, int data_size) { + int before_axis_size = 1; + for (int i = 0; i < axis; ++i) { + before_axis_size *= inputs_output_shape[0][i]; + } + + int after_axis_size = data_size; + for (size_t i = (size_t)(axis) + 1; i < shape_size; ++i) { + after_axis_size *= inputs_output_shape[0][i]; + } + int axis_offset = 0; + uint8_t *dst_base = (output); + int output_stride = after_axis_size * inputs_output_shape[input_num][axis]; + for (int i = 0; i < input_num; ++i) { + const uint8_t *src_base = (input[i]); + if (inputs_output_shape[i] == NULL) { + continue; + } + int input_stride = after_axis_size * inputs_output_shape[i][axis]; + NNACL_CHECK_ZERO_RETURN(thread_num); + int offset = UP_DIV(input_stride, thread_num); + int count = input_stride - offset * task_id; + if (count <= 0) { + axis_offset += inputs_output_shape[i][axis]; + continue; + } + count = MSMIN(offset, count); + for (int j = 0; j < before_axis_size; j++) { + const uint8_t *src = src_base + j * input_stride + task_id * offset; + uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size + task_id * offset; + memcpy(dst, src, count); + } + axis_offset += inputs_output_shape[i][axis]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h new file mode 100644 index 0000000000000000000000000000000000000000..ea85f6e650820ad159f848c816b64139d1f3c321 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CONCAT_BASE_H_ +#define NNACL_FP32_CONCAT_BASE_H_ + +#include +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, + int task_id, int thread_num, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONCAT_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c new file mode 100644 index 0000000000000000000000000000000000000000..240b851617df7f71561111bc8a6b8b35c48942b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/conv1x1_base.h" + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) { + /* support nhwc */ + char *src = (char *)src_ptr; + char *dst = (char *)dst_ptr; + for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { + int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_; + if (src_h < 0 || src_h >= conv_param->input_h_) { + continue; + } + const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size; + char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size; + for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { + int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_; + if (src_w < 0 || src_w >= conv_param->input_w_) { + continue; + } + memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size, + src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size); + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h new file mode 100644 index 0000000000000000000000000000000000000000..6ab0322b8c787c01c3e58deb9bd8f53b1b39ad51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_CONV1X1_BASE_H_ +#define NNACL_BASE_CONV1X1_BASE_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CONV1X1_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c new file mode 100644 index 0000000000000000000000000000000000000000..59aac49c609e1ba6e05363008f4ceeb9b59e1e61 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
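Conv1x1InputPack gathers only the spatial positions a 1x1 kernel actually reads, so strided 1x1 convolutions can reuse the dense GEMM path. A sketch with hypothetical dimensions (ConvParameter fields exactly as the function uses them):

ConvParameter p = {0};
p.input_h_ = 4;  p.input_w_ = 4;  p.input_channel_ = 8;
p.output_h_ = 2; p.output_w_ = 2;
p.stride_h_ = 2; p.stride_w_ = 2;
p.pad_u_ = 0;    p.pad_l_ = 0;
float src[4 * 4 * 8] = {0};
float dst[2 * 2 * 8];
/* dst receives rows {0, 2} x cols {0, 2} of the NHWC input, all 8 channels each */
Conv1x1InputPack(src, dst, &p, (int)sizeof(float));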
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/errorcode.h" + +#define MIN_UNIT 2 +#define MAX_UNIT 8 + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) { + return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 && + conv_param->input_channel_ == conv_param->output_channel_ && conv_param->output_w_ >= 4 && + conv_param->output_h_ >= thread_num * 4; // each thread should get at least four output rows +} +#endif + +bool CheckWinogradInputOutputUnit(int input_unit, int output_unit) { + if (input_unit != 4 && input_unit != 6 && input_unit != 8) { + return false; + } + if ((output_unit >= input_unit) || (output_unit < 2)) { + return false; + } + return true; +} + +// Cost model from the paper "Fast Algorithms for Convolutional Neural Networks". +// If the estimated gain over Im2col is large enough, the Winograd algorithm is chosen. +int SelectOutputUnit(const ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_c = conv_param->input_channel_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_c = conv_param->output_channel_; + if (conv_param->op_parameter_.thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + int unit2 = UP_DIV(out_w * out_h, C12NUM * conv_param->op_parameter_.thread_num_); + int max_out_unit = (int)(sqrtf((float)unit2)); + max_out_unit = max_out_unit < MAX_UNIT ? max_out_unit : MAX_UNIT; + max_out_unit = max_out_unit > MIN_UNIT ? 
max_out_unit : MIN_UNIT; + + int unit = 0; + float max_rate = 0.0f; + float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w; + + for (int i = MIN_UNIT; i <= max_out_unit; ++i) { + int input_unit = i + kernel_w - 1; + if (!CheckWinogradInputOutputUnit(input_unit, i)) { + continue; + } + float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; + float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) * + UP_DIV(out_w, i) * UP_DIV(out_h, i); + float reduce_rate = common_cost / wino_cost - penalty; + if (reduce_rate > max_rate) { + max_rate = reduce_rate; + unit = i; + } + } + if (max_rate < 1.0f) { + return 1; + } + // If output_unit is 1, then it is conventional convolution + return unit; +} + +bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return false; + } + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ != 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + return true; + } + } + return false; +} + +bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + return true; + } + } + return false; +} + +bool CheckAvxUseSWConv(const ConvParameter *conv_param, int thread_nr_) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_w_, conv_param->input_h_, false); + if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ % C8NUM == 0 && + (conv_param->input_w_ * conv_param->input_h_ >= thread_nr_)) { + return true; + } + } else { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ > C128NUM) { // conv1d kernel + return false; + } else if (conv_param->input_channel_ / conv_param->op_parameter_.thread_num_ <= C16NUM && + conv_param->input_h_ >= thread_nr_ && + (conv_param->kernel_h_ < C7NUM || conv_param->input_h_ / conv_param->kernel_h_ >= C4NUM) && + (conv_param->kernel_w_ < C7NUM || conv_param->input_w_ / conv_param->kernel_w_ >= C4NUM)) { + return true; + } + } + return false; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h new file mode 100644 index 0000000000000000000000000000000000000000..29dfa5690eec65929d9da9855224cb69d9eaaac5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
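A worked instance of the cost model above (hypothetical layer, thread_num 1): for a 3x3 kernel on a 56x56, 64-in/64-out layer, output unit 4 gives input_unit 6, so wino_cost = ((2 + 64) * 36 * 64 + (6 + 4) * 4 * 64) * 14 * 14 ≈ 3.03e7 while common_cost = 56 * 56 * 64 * 64 * 9 ≈ 1.16e8; the gain is 1.16e8 / 3.03e7 - 0.48 ≈ 3.3. Output unit 2 only reaches about 1.9, so SelectOutputUnit returns 4 and the caller uses F(4x4, 3x3).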
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_CONV_DEPTHWISE_BASE_H_ +#define NNACL_BASE_CONV_DEPTHWISE_BASE_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param); +bool CheckAvxUseSWConv(const ConvParameter *conv_param, int thread_nr_); + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num); +#endif + +bool CheckWinogradInputOutputUnit(int input_unit, int output_unit); + +bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_BASE_CONV_DEPTHWISE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c new file mode 100644 index 0000000000000000000000000000000000000000..26b58e0fdd36bef9bf612d82c0edca6fac07164c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/crop_base.h" +#include "nnacl_c/errorcode.h" + +int CropPadOffset(int input_dim, CropParameter *crop_para, int64_t *in_offset) { + int64_t axis = crop_para->axis_; + int offsets_size = crop_para->offset_size_; + if (offsets_size > 1) { + NNACL_CHECK_TRUE_RET(axis + offsets_size == input_dim, NNACL_ERR); + } + for (int i = 0; i < input_dim; i++) { + int crop_offset = 0; + if (i >= axis) { + if (offsets_size == 1) { + crop_offset = crop_para->offset_[0]; + } else if (offsets_size > 1) { + if (i - axis < CROP_OFFSET_MAX_SIZE) { + crop_offset = crop_para->offset_[i - axis]; + } + } + } + in_offset[i] = crop_offset; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h new file mode 100644 index 0000000000000000000000000000000000000000..4b036d88e701c286757fad1c329d3212a60da90e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
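A usage sketch for CropPadOffset (hypothetical parameter values; CropParameter fields exactly as the function uses them):

CropParameter crop = {0};
crop.axis_ = 2;
crop.offset_size_ = 2;
crop.offset_[0] = 1;
crop.offset_[1] = 2;
int64_t in_offset[4] = {0};
/* axis_ + offset_size_ must equal the input rank when several offsets are given */
CropPadOffset(4, &crop, in_offset);  /* in_offset = {0, 0, 1, 2} */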
+ */ + +#ifndef NNACL_BASE_CROP_BASE_H_ +#define NNACL_BASE_CROP_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#define CROP_OFFSET_MAX_SIZE 4 + +#ifdef __cplusplus +extern "C" { +#endif + +int CropPadOffset(int input_dim, CropParameter *crop_para, int64_t *in_offset); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CROP_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c new file mode 100644 index 0000000000000000000000000000000000000000..11bd8478690c62890a389c9b80b4242faff34fd5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/depth_to_space_base.h" +#include "nnacl_c/errorcode.h" + +void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = (size_t)block_size * param->out_stride_dim2_ * param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + int64_t out_offset = (out_offset_w + l * param->out_stride_dim1_) * param->data_type_size_; + int64_t in_offset = (in_offset_w + l * block_size * param->out_stride_dim2_) * param->data_type_size_; + memcpy((int8_t *)output + out_offset, (int8_t *)input + in_offset, copy_size); + } + } + } + } +} + +void DepthToSpaceCRDForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim3 = in_shape[3]; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < in_shape_dim3; ++l) { 
+ int64_t offset = l % (block_size * block_size); + int64_t out_offset_c = + out_offset_w + + offset / block_size * block_size * in_shape_dim2 * in_shape_dim3 / (block_size * block_size) + + offset % block_size * in_shape_dim3 / (block_size * block_size); + int64_t out_offset = (out_offset_c + l / (block_size * block_size)) * param->data_type_size_; + int64_t in_offset = (in_offset_w + l) * param->data_type_size_; + memcpy((int8_t *)output + out_offset, (int8_t *)input + in_offset, copy_size); + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h new file mode 100644 index 0000000000000000000000000000000000000000..57c0e38a934f8e92c46ae4aa7bb76b2b6061c15b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_DEPTH_TO_SPACE_H_ +#define NNACL_DEPTH_TO_SPACE_H_ + +#include <string.h> +#include "nnacl_c/kernel/depth_to_space.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param); +void DepthToSpaceCRDForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_DEPTH_TO_SPACE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c new file mode 100644 index 0000000000000000000000000000000000000000..9e6a05b19cd36f83e11d037e6bd3ff11768af056 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c @@ -0,0 +1,59 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
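A worked example for the plain NHWC path above (strides are hypothetical but consistent with the shapes; DepthToSpaceArgs fields exactly as the function uses them):

/* [1,1,1,4] -> [1,2,2,1] with block_size 2: each output pixel (h, w) takes input channel 2h + w */
DepthToSpaceArgs args = {0};
args.block_size_ = 2;
args.data_type_size_ = sizeof(float);
args.in_stride_dim0_ = 4;  args.in_stride_dim1_ = 4;  args.in_stride_dim2_ = 4;
args.out_stride_dim0_ = 4; args.out_stride_dim1_ = 2; args.out_stride_dim2_ = 1;
int in_shape[4] = {1, 1, 1, 4};
float in[4] = {0.0f, 1.0f, 2.0f, 3.0f}, out[4];
DepthToSpaceForNHWC(in, out, in_shape, &args);  /* out = {0, 1, 2, 3} laid out as 2x2x1 */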
+ */ + +#include "nnacl_c/base/fill_base.h" +#include "nnacl_c/fill_base_simd.h" + +int FillFp32(float *output, int size, float data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + int index = 0; + + SIMD_RUN_NO_SCALAR(FillFp32, index, output, size, data); + + for (; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; +} + +int FillInt32(int *output, int size, int data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + int index = 0; + + SIMD_RUN_NO_SCALAR(FillInt32, index, output, size, data); + + for (; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; +} + +int FillBool(bool *output, int size, bool data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + for (int index = 0; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h new file mode 100644 index 0000000000000000000000000000000000000000..8da977c71b26dfc47182112f4c1eefa464f47957 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FILL_BASE_H_ +#define NNACL_FILL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fill_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int FillFp32(float *output, int size, float data); +int FillInt32(int *output, int size, int data); +int FillBool(bool *output, int size, bool data); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FILL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..08bfb4c2fecb98a402f96ea4faa27817915f62a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_FILL_BASE_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_FILL_BASE_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int FillFp32@SIMD_INSTRUCTION@(int index, float *output, int size, float data) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MOV_F32(data)); + } + return index; +} + +static inline int FillInt32@SIMD_INSTRUCTION@(int index, int *output, int size, int data) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MOV_EPI32(data)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c new file mode 100644 index 0000000000000000000000000000000000000000..e062b90182fadf67dfd696ecd01f7bfc79936c1d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c @@ -0,0 +1,81 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/format_transpose.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/pack_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pack_fp16.h" +#endif + +int TransposeFp32Data(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, + const int batch, const int channel, const int plane) { + if (src_format == Format_NHWC && dst_format == Format_NCHW) { + PackNHWCToNCHWFp32(src_data, dst_data, batch, plane, channel, 0, 1); + } else if (src_format == Format_NCHW && dst_format == Format_NHWC) { + PackNCHWToNHWCFp32(src_data, dst_data, batch, plane, channel, 0, 1); + } else if (src_format == Format_NCHW && dst_format == Format_NC4HW4) { + PackNCHWToNC4HW4Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC4HW4 && dst_format == Format_NCHW) { + PackNC4HW4ToNCHWFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC4HW4) { + PackNHWCToNC4HW4Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC4HW4 && dst_format == Format_NHWC) { + PackNC4HW4ToNHWCFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC8HW8) { + PackNHWCToNC8HW8Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NHWC) { + PackNC8HW8ToNHWCFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NCHW) { + PackNC8HW8ToNCHWFp32(src_data, dst_data, batch, plane, channel); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +#ifdef ENABLE_FP16 +int TransposeFp16Data(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, int batch, + int channel, int plane) { + if (src_format == Format_NCHW && dst_format == Format_NC8HW8) { + PackNCHWFp16ToNC8HW8Fp16(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC8HW8) { + return NNACL_ERR; + } else if (src_format == Format_NC8HW8 && dst_format == Format_NCHW) { + PackNC8HW8ToNCHWFp16(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NHWC) { + PackNC8HW8ToNHWCFp16((float16_t *)src_data, (float16_t *)dst_data, batch, plane, channel); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} +#endif + +int TransData(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, TypeIdC data_type, + const int batch, const int channel, const int plane) { + switch (data_type) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + return TransposeFp32Data(src_data, dst_data, src_format, dst_format, batch, channel, plane); +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + return TransposeFp16Data(src_data, dst_data, src_format, dst_format, batch, channel, plane); +#endif + default: + return NNACL_ERR; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..638e2f0faf73171158a5dfbaf93903395eaf5dd4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FORMAT_TRANSPOSE_H_ +#define NNACL_FORMAT_TRANSPOSE_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int TransData(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, TypeIdC data_type, + const int batch, const int channel, const int plane); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FORMAT_TRANSPOSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c new file mode 100644 index 0000000000000000000000000000000000000000..8721b56820b1c46704f0440f7b59e3b9d90cb46d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c @@ -0,0 +1,44 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <string.h> +#include "nnacl_c/base/gather_base.h" + +int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices, + int64_t index_num, void *output, int64_t byte_out_stride, int *error_index) { + if (input == NULL || output == NULL || indices == NULL || error_index == NULL) { + return NNACL_NULL_PTR; + } + const int8_t *int8_in = (const int8_t *)input; + int8_t *int8_out = (int8_t *)output; + int64_t in_stride = byte_inner_size * limit; + for (int64_t m = 0; m < outer_size; ++m) { + int8_t *int8_out_m = int8_out; + for (int64_t i = 0; i < index_num; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + if (index < 0 || index >= limit) { + *error_index = index; + return NNACL_GATHER_INDICES_VALUE_INVALID; + } else { + memcpy(int8_out_m, int8_in + index * byte_inner_size, byte_inner_size); + } + int8_out_m += byte_inner_size; + } + int8_in += in_stride; + int8_out += byte_out_stride; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h new file mode 100644 index 0000000000000000000000000000000000000000..f47a19b636ef6b1d836dd3ab26f95beb61938ff1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
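A usage sketch for the byte-oriented Gather (hypothetical 4x2 fp32 input gathered on axis 0; negative indices wrap once into [0, limit)):

float in[4][2] = {{0, 1}, {2, 3}, {4, 5}, {6, 7}};
int indices[2] = {3, -4};  /* -4 wraps to 0 */
float out[2][2];
int bad_index = 0;
int ret = Gather(in, /*outer_size=*/1, /*byte_inner_size=*/2 * sizeof(float), /*limit=*/4,
                 indices, /*index_num=*/2, out, /*byte_out_stride=*/4 * sizeof(float), &bad_index);
/* ret == NNACL_OK, out = {{6, 7}, {0, 1}}; on failure bad_index holds the offending index */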
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_BASE_H_ +#define NNACL_GATHER_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices, + int64_t index_num, void *output, int64_t byte_out_stride, int *error_index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_GATHER_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c new file mode 100644 index 0000000000000000000000000000000000000000..93d460b2c94caa221f4f417142523f7935dad42b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c @@ -0,0 +1,163 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/base/gather_d_base.h" + +int CheckIndexValue_int32_t(int32_t *index, const int max_index, const size_t *index_shape, + const size_t index_shape_size) { + // check index + size_t index_size = 1; + for (size_t i = 0; i < index_shape_size; ++i) { + index_size *= index_shape[i]; + } + + for (size_t i = 0; i < index_size; ++i) { + if (index[i] >= max_index || index[i] < -max_index) { + return NNACL_ERR; + } + if (index[i] < 0) { + index[i] = max_index + index[i]; + } + } + return NNACL_OK; +} + +int CheckIndexValue_int64_t(int64_t *index, const int max_index, const size_t *index_shape, + const size_t index_shape_size) { + // check index + size_t index_size = 1; + for (size_t i = 0; i < index_shape_size; ++i) { + index_size *= index_shape[i]; + } + for (size_t i = 0; i < index_size; ++i) { + if (index[i] >= max_index || index[i] < -max_index) { + return NNACL_ERR; + } + if (index[i] < 0) { + index[i] = max_index + index[i]; + } + } + return NNACL_OK; +} + +int InitCalVec(size_t *in_strides, size_t *out_strides, size_t *pos, const size_t *input_shape, + const size_t input_shape_size, const size_t *output_shape, const size_t output_shape_size) { + // in_strides + NNACL_CHECK_NULL_RETURN_ERR(in_strides); + for (size_t i = 0; i < input_shape_size; ++i) { + in_strides[i] = 1; + } + for (int i = (int)input_shape_size - 2; i >= 0; --i) { + in_strides[i] = input_shape[i + 1] * in_strides[i + 1]; + } + + // out_strides + NNACL_CHECK_NULL_RETURN_ERR(out_strides); + for (size_t i = 0; i < output_shape_size; ++i) { + out_strides[i] = 1; + } + for (int i = (int)output_shape_size - 2; i >= 0; --i) { + out_strides[i] = output_shape[i + 1] * out_strides[i + 1]; + } + + NNACL_CHECK_NULL_RETURN_ERR(pos); + for (size_t i = 0; i < output_shape_size; ++i) { + pos[i] = 0; + } + return NNACL_OK; +} + +#define COPY_TASK_IMPL(type0, type1) \ + int CopyTask_Input_##type0##_Index_##type1( \ + type0 *output, const type0 *input, const type1 *index, size_t cur_dim, size_t *pos, const size_t dim, \ + const size_t *output_shape, const size_t output_shape_size, const size_t *in_strides, const size_t *out_strides) { \ + if (pos == NULL || out_strides == NULL || in_strides == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + for (size_t i = 0; i < output_shape[cur_dim]; ++i) { \ + pos[cur_dim] = i; \ + if (cur_dim == output_shape_size - 1) { \ + size_t input_offset = 0; \ + size_t out_offset = 0; \ + for (size_t j = 0; j < output_shape_size; ++j) { \ + out_offset += pos[j] * out_strides[j]; \ + } \ + size_t cur_index = pos[dim]; \ + pos[dim] = index[out_offset]; \ + for (size_t j = 0; j < output_shape_size; ++j) { \ + input_offset += pos[j] * in_strides[j]; \ + } \ + ((type0 *)output)[out_offset] = ((const type0 *)input)[input_offset]; \ + pos[dim] = cur_index; \ + } else { \ + CopyTask_Input_##type0##_Index_##type1(output, input, index, cur_dim + 1, pos, dim, output_shape, \ + output_shape_size, in_strides, out_strides); \ + } \ + } \ + return NNACL_OK; \ + } + +COPY_TASK_IMPL(bool, int32_t) +COPY_TASK_IMPL(bool, int64_t) +COPY_TASK_IMPL(int16_t, int32_t) +COPY_TASK_IMPL(int16_t, int64_t) +COPY_TASK_IMPL(int32_t, int32_t) +COPY_TASK_IMPL(int32_t, int64_t) +COPY_TASK_IMPL(int64_t, int32_t) +COPY_TASK_IMPL(int64_t, int64_t) +COPY_TASK_IMPL(float, int32_t) +COPY_TASK_IMPL(float, int64_t) +#ifdef ENABLE_FP16 +COPY_TASK_IMPL(float16_t, int32_t) +COPY_TASK_IMPL(float16_t, int64_t) +#endif + +#define GATHER_D_IMPL(type0, type1) \ + GATHER_D_IMPL_DECLARATION(type0, type1) { \ + if (output == NULL || input == 
NULL || index == NULL || input_shape == NULL || output_shape == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + int max_index = input_shape[dim]; \ + int ret = CheckIndexValue_##type1(index, max_index, output_shape, output_shape_size); \ + if (ret != NNACL_OK) { \ + return ret; \ + } \ + size_t in_strides[MAX_SHAPE_SIZE]; \ + size_t out_strides[MAX_SHAPE_SIZE]; \ + size_t pos[MAX_SHAPE_SIZE]; \ + ret = InitCalVec(in_strides, out_strides, pos, input_shape, input_shape_size, output_shape, output_shape_size); \ + if (ret != NNACL_OK) { \ + return ret; \ + } \ + ret = CopyTask_Input_##type0##_Index_##type1(output, input, index, 0, pos, dim, output_shape, output_shape_size, \ + in_strides, out_strides); \ + return ret; \ + } + +GATHER_D_IMPL(bool, int32_t) +GATHER_D_IMPL(bool, int64_t) +GATHER_D_IMPL(int16_t, int32_t) +GATHER_D_IMPL(int16_t, int64_t) +GATHER_D_IMPL(int32_t, int32_t) +GATHER_D_IMPL(int32_t, int64_t) +GATHER_D_IMPL(int64_t, int32_t) +GATHER_D_IMPL(int64_t, int64_t) +GATHER_D_IMPL(float, int32_t) +GATHER_D_IMPL(float, int64_t) +#ifdef ENABLE_FP16 +GATHER_D_IMPL(float16_t, int32_t) +GATHER_D_IMPL(float16_t, int64_t) +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h new file mode 100644 index 0000000000000000000000000000000000000000..a8270b01c18de548acc54e78e118953f7648bc81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_GATHER_D_BASE_H_ +#define NNACL_GATHER_D_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define GATHER_D(type0, type1, output, input, index, input_shape, input_shape_size, output_shape, output_shape_size, \ + dim) \ + GatherD_Input_##type0##_Index_##type1(output, input, index, input_shape, input_shape_size, output_shape, \ + output_shape_size, dim) + +#define GATHER_D_IMPL_DECLARATION(type0, type1) \ + int GatherD_Input_##type0##_Index_##type1( \ + type0 *output, const type0 *input, type1 *index, const size_t *input_shape, const size_t input_shape_size, \ + const size_t *output_shape, const size_t output_shape_size, const size_t dim) + +GATHER_D_IMPL_DECLARATION(bool, int32_t); +GATHER_D_IMPL_DECLARATION(bool, int64_t); +GATHER_D_IMPL_DECLARATION(int16_t, int32_t); +GATHER_D_IMPL_DECLARATION(int16_t, int64_t); +GATHER_D_IMPL_DECLARATION(int32_t, int32_t); +GATHER_D_IMPL_DECLARATION(int32_t, int64_t); +GATHER_D_IMPL_DECLARATION(int64_t, int32_t); +GATHER_D_IMPL_DECLARATION(int64_t, int64_t); +GATHER_D_IMPL_DECLARATION(float, int32_t); +GATHER_D_IMPL_DECLARATION(float, int64_t); +#ifdef ENABLE_FP16 +GATHER_D_IMPL_DECLARATION(float16_t, int32_t); +GATHER_D_IMPL_DECLARATION(float16_t, int64_t); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_GATHER_D_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c new file mode 100644 index 0000000000000000000000000000000000000000..41226bebb316f3f1d7907537a650055c2620e6ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c @@ -0,0 +1,342 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
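A usage sketch for the GATHER_D convenience macro above (shapes are hypothetical; note the index tensor must have the output shape):

float in[4] = {1, 2, 3, 4};     /* logical shape [2, 2] */
int32_t idx[4] = {0, 0, 1, 0};  /* out[i][j] = in[i][idx[i][j]] for dim = 1 */
float out[4];
size_t shape[2] = {2, 2};
int ret = GATHER_D(float, int32_t, out, in, idx, shape, 2, shape, 2, 1);
/* ret == NNACL_OK, out = {1, 1, 4, 3} */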
+ */ +#include "nnacl_c/base/minimal_filtering_generator.h" +#include +#include +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void Polynomial(const float *interval, float *m, int degree) { + for (int i = 0; i < degree; ++i) { + float mul = 1; + for (int j = 0; j < degree; ++j) { + if (i == j) { + continue; + } + mul *= (interval[i] - interval[j]); + } + m[i] = mul; + } +} + +void DiagonalPlusMatrix(const float *matrix, float *diagonal_matrix, int degree) { + int data_num = (degree + 1) * (degree + 1); + memset(diagonal_matrix, 0, data_num * sizeof(float)); + for (int i = 0; i < degree; ++i) { + for (int j = 0; j < degree; ++j) { + if (j == i) { + diagonal_matrix[i * (degree + 1) + j] = matrix[i]; + } + } + } + diagonal_matrix[data_num - 1] = 1; +} + +void ResidueMatrix(const float *interval, float *b, int row, int col) { + // row : input unit, col : output_unit + // result : matrix b + int len = row * col; + memset(b, 0, len * sizeof(float)); + for (int i = 0; i < row - 1; ++i) { + for (int j = 0; j < col; ++j) { + b[i * col + j] = pow(interval[i], j); + } + } + b[len - 1] = 1; +} + +int LT(const float *poly_array, float *matrix_lt, int n) { + if (n > MAX_LEN) { + return NNACL_ERR; + } + float coefficient_array[MAX_LEN]; // n + float poly[MAX_LEN]; // n + + Polynomial(poly_array, poly, n); + for (int i = 0; i < n; ++i) { + // get coefficient + int index = 1; + memset(coefficient_array, 0, n * sizeof(float)); + coefficient_array[0] = 1; + for (int j = 0; j < n; ++j) { + if (j == i) continue; + float poly_coe = poly_array[j] == 0 ? 0 : -poly_array[j]; + coefficient_array[index] = 1; + for (int k = index - 1; k > 0; --k) { + coefficient_array[k] = coefficient_array[k] * poly_coe + coefficient_array[k - 1]; + } + coefficient_array[0] *= poly_coe; + index++; + } + + // lx[i, 0].nth(j) / f[i] + int setp = i * n; + for (int l = 0; l < n; ++l) { + matrix_lt[setp + l] = coefficient_array[l] / poly[i]; + } + } // matrix L row loop + return NNACL_OK; +} + +void T(const float *poly_array, float *matrix_t, int n) { + memset(matrix_t, 0, n * (n + 1) * sizeof(float)); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n + 1; ++j) { + if (j == i) matrix_t[i * (n + 1) + j] = 1; + if (j == n) { + if (poly_array[i] == 0) { + matrix_t[i * (n + 1) + j] = 0; + } else { + matrix_t[i * (n + 1) + j] = -pow(poly_array[i], n); + } + } + } + } +} + +int B(const float *poly_array, float *matrix_b, int in_unit) { + memset(matrix_b, 0, in_unit * in_unit * sizeof(float)); + int n = in_unit - 1; + if ((n * n) > MAX_LEN || (n * in_unit) > MAX_LEN) { + return NNACL_ERR; + } + float matrix_l[MAX_LEN]; // n * n + float matrix_lt[MAX_LEN]; // n * n + float matrix_t[MAX_LEN]; // n * in_unit + + T(poly_array, matrix_t, n); + if (LT(poly_array, matrix_lt, n) != NNACL_OK) { + return NNACL_ERR; + } + MatrixTranspose(matrix_lt, matrix_l, n, n); + MatrixMultiply(matrix_l, matrix_t, matrix_b, n, n, in_unit); + matrix_b[in_unit * in_unit - 1] = 1; + return NNACL_OK; +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + int cnt = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + for (int y = 0; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + matrix_c[cnt++] = tmp; + } + cnt += 
c4_channel / 4 - in_channel; + } + } +} +#endif + +void GenerateIntervalArray(float *array, float interval, int degree) { + array[0] = 0; + for (int i = 1; i < degree; ++i) { + int coefficient = pow(-1, i - 1); + array[i] = array[i - 1] + interval * i * coefficient; + } +} + +void MatrixTranspose(const float *matrix, float *trans_matrix, int row, int col) { + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + trans_matrix[i * row + j] = matrix[j * col + i]; + } + } +} + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} + +int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g, + float *matrix_gt, float coefficient, int out_unit, int filter_size) { + int in_unit = out_unit + filter_size - 1; + int degree = in_unit - 1; + if (degree > MAX_LEN || (in_unit * in_unit) > MAX_LEN || (in_unit * filter_size) > MAX_LEN) { + return NNACL_ERR; + } + float polynomial_m[MAX_LEN]; // degree + float diagonal_matrix[MAX_LEN]; // input_unit * input_unit + float inverse_diagonal_matrix[MAX_LEN]; // input_unit * input_unit + + // get diagonal matrix + float interval[MAX_LEN]; // degree + GenerateIntervalArray(interval, coefficient, degree); + Polynomial(interval, polynomial_m, degree); + DiagonalPlusMatrix(polynomial_m, diagonal_matrix, degree); + if (diagonal_matrix[0] < 0) { + for (int i = 0; i < in_unit; ++i) { + if (diagonal_matrix[i] != 0) diagonal_matrix[i] *= -1; + } + } + + // inverse diagonal matrix + for (int j = 0; j < in_unit * in_unit; ++j) { + if (diagonal_matrix[j] != 0) { + inverse_diagonal_matrix[j] = 1.0 / diagonal_matrix[j]; + } else { + inverse_diagonal_matrix[j] = 0; + } + } + + // get matrix A && AT + ResidueMatrix(interval, matrix_a, in_unit, out_unit); + MatrixTranspose(matrix_a, matrix_at, in_unit, out_unit); + + // get matrix B + int ret = B(interval, matrix_bt, in_unit); + if (ret != NNACL_OK) { + return ret; + } + MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit); + MatrixMultiply(diagonal_matrix, matrix_b, matrix_bt, in_unit, in_unit, in_unit); + MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit); + + // get matrix G && GT + float tmp_g[MAX_LEN]; // in_unit * filter_size + ResidueMatrix(interval, matrix_g, in_unit, filter_size); + MatrixTranspose(matrix_g, tmp_g, in_unit, filter_size); + MatrixMultiply(tmp_g, inverse_diagonal_matrix, matrix_gt, filter_size, in_unit, in_unit); + MatrixTranspose(matrix_gt, matrix_g, filter_size, in_unit); + return NNACL_OK; +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, + const float *bias, int m, int k, int n) { + int count = 0; + MS_FLOAT32X4 bias_ptr = MS_MOVQ_F32(0); + if (bias != NULL) { + bias_ptr = MS_LDQ_F32(bias); + } + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + MS_FLOAT32X4 res = MS_MOVQ_F32(0); + for (int i = 0; i < k; i++) { + res = MS_MLAQ_F32(res, matrix_a[h_offset + i], matrix_b[w + i * n]); + } + matrix_c[count] = MS_ADDQ_F32(res, bias_ptr); + count++; + } + } +} +#endif + +int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float 
*matrix_gt, + int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack) { + if (oc_block == 0) { + return NNACL_PARAM_INVALID; + } + // original weight format : ohwi + int oc_block_num = UP_DIV(batch, oc_block); + int block_stride = channel * oc_block; + int block_num_stride = block_stride * oc_block_num; + + // trans_filter = G*g*GT (g represents weight_data) + // separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd + float *tmp_data = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float))); + if (tmp_data == NULL) { + return NNACL_ERR; + } + float *trans_out_data = (float *)(malloc(channel * input_unit * input_unit * sizeof(float))); + if (trans_out_data == NULL) { + free(tmp_data); + return NNACL_ERR; + } + +#ifndef ENABLE_ARM + float *tmp_data1 = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float))); + if (tmp_data1 == NULL) { + free(tmp_data); + free(trans_out_data); + return NNACL_ERR; + } + float *trans_out_data1 = (float *)(malloc(channel * input_unit * input_unit * sizeof(float))); + if (trans_out_data1 == NULL) { + free(tmp_data); + free(tmp_data1); + free(trans_out_data); + return NNACL_ERR; + } +#endif + + int input_oz_offset = kernel_unit * kernel_unit * channel; + for (int i = 0; i < batch; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM + // tmp_data = g * GT + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit, + channel, channel * 4); + // tmp_data1 = (tmp_data)T + PackHWCToWHC(tmp_data, tmp_data1, kernel_unit, input_unit, channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit, kernel_unit, input_unit, channel, + channel * 4); + // trans_out_data = (trans_out_data1)T + PackHWCToWHC(trans_out_data1, trans_out_data, input_unit, input_unit, channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit, + channel, channel * 4); + // trans = (tmp * GT)T + MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit, kernel_unit, input_unit, channel, + channel * 4); +#endif + if (pack) { + int in_offset = 0; + for (int j = 0; j < input_unit; ++j) { + for (int k = 0; k < input_unit; ++k) { + for (int c = 0; c < channel; ++c) { + *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += channel; + output_oz_offset += block_num_stride; + } + } + } else { + memcpy(winograd_data + i * channel * input_unit * input_unit, trans_out_data, + channel * input_unit * input_unit * sizeof(float)); + } + } +#ifndef ENABLE_ARM + free(tmp_data1); + free(trans_out_data1); +#endif + free(tmp_data); + free(trans_out_data); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..44f5bd005e31f551964da46f2b72eec4ccee5b06 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MINIMAL_FILTERING_GENERATOR_H_ +#define NNACL_MINIMAL_FILTERING_GENERATOR_H_ + +#ifdef ENABLE_ARM +#include <arm_neon.h> +#endif +#include <stdbool.h> + +#ifdef __cplusplus +extern "C" { +#endif +void Polynomial(const float *interval, float *m, int degree); + +void DiagonalPlusMatrix(const float *matrix, float *diagonal_matrix, int degree); + +void ResidueMatrix(const float *interval, float *b, int row, int col); + +int LT(const float *poly_array, float *matrix_lt, int n); + +void T(const float *poly_array, float *matrix_t, int n); + +int B(const float *poly_array, float *matrix_b, int in_unit); + +void GenerateIntervalArray(float *array, float interval, int degree); + +void MatrixTranspose(const float *matrix, float *trans_matrix, int row, int col); + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n); + +int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g, + float *matrix_gt, float coefficient, int out_unit, int filter_size); +void MatrixMultiplyWinograd(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel); + +int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, + int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_MINIMAL_FILTERING_GENERATOR_H_
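[Editor's note, not part of the patch: a minimal usage sketch for the Cook-Toom generator declared above, assuming the common F(2x2, 3x3) case, i.e. out_unit = 2 and filter_size = 3, so in_unit = 4; the 0.5f interval coefficient and the buffer names are illustrative.]

  float a[8], at[8], b[16], bt[16], g[12], gt[12];
  // matrix_a is in_unit x out_unit, matrix_b is in_unit x in_unit, matrix_g is in_unit x filter_size
  if (CookToomFilter(a, at, b, bt, g, gt, 0.5f, 2, 3) != NNACL_OK) {
    // in_unit was too large for the fixed MAX_LEN scratch buffers
  }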
+ */ + +#include "nnacl_c/base/scatter_nd_binary.h" +#include +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/scatter_nd_binary_simd.h" + +int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id) { + if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->num_unit); + if (type == 0) { + float *update_fp32 = (float *)update; + float *output_fp32 = (float *)output; + for (int i = begin; i < end; i++) { + const float *update_data = update_fp32 + i * param->unit_size; + float *output_data = output_fp32 + output_unit_offsets[i]; + int j = 0; + + SIMD_RUN_NO_SCALAR(ScatterNDAddFp32, j, update_data, param->unit_size, output_data); + for (; j < param->unit_size; j++) { + output_data[j] += update_data[j]; + } + } + } else { + int *update_int32 = (int *)update; + int *output_int32 = (int *)output; + for (int i = begin; i < end; i++) { + const int *update_data = update_int32 + i * param->unit_size; + int *output_data = output_int32 + output_unit_offsets[i]; + int j = 0; + + SIMD_RUN_NO_SCALAR(ScatterNDAddInt32, j, update_data, param->unit_size, output_data); + for (; j < param->unit_size; j++) { + output_data[j] += update_data[j]; + } + } + } + return NNACL_OK; +} + +int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, const ScatterNDParameter *param, + int task_id) { + if (param->op_parameter.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->num_unit); + + int data_type_len = param->data_type_len; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(param->unit_size, data_type_len, NNACL_ERR); + + for (int i = begin; i < end; i++) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_unit_offsets[i], data_type_len, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(i, param->unit_size * data_type_len, NNACL_ERR); + (void)memcpy((int8_t *)output + output_unit_offsets[i] * data_type_len, + (int8_t *)update + i * param->unit_size * data_type_len, param->unit_size * data_type_len); + } + return NNACL_OK; +} + +int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id) { + if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->num_unit); + if (type == 0) { + float *update_fp32 = (float *)update; + float *output_fp32 = (float *)output; + for (int i = begin; i < end; i++) { + const float *update_data = update_fp32 + i * param->unit_size; + float *output_data = output_fp32 + output_unit_offsets[i]; + int j = 0; + + SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data); + for (; j < param->unit_size; j++) { + output_data[j] = fmaxf(update_data[j], output_data[j]); + } + } + } + return NNACL_OK; +} diff --git 
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h new file mode 100644 index 0000000000000000000000000000000000000000..098d87d768844d1eeb84710f4f4f5cbd7450789c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_SCATTER_ND_BINARY_H_ +#define NNACL_BASE_SCATTER_ND_BINARY_H_ + +#include "nnacl_c/scatter_nd_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, const ScatterNDParameter *param, + int task_id); + +int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id); + +int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_SCATTER_ND_BINARY_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..25258ede204f5697ef05437c1e3a8ef78ae5ff1a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + + static inline int ScatterNDAddFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); + } + return index; +} + +static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); + } + return index; +} + +static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); + } + return index; +} + +static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif // NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h new file mode 100644 index 0000000000000000000000000000000000000000..988ce880831a033a160110483e2c08b7b846766c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_ +#define MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/sequence_unstack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SequenceUnstack(const void *input, void **output, const SequenceUnstackParameter *para, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c new file mode 100644 index 0000000000000000000000000000000000000000..c0bca17441838c172ecd0ea2c7a040ea3a56f24d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c @@ -0,0 +1,173 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/base/slice_base.h" +#include <string.h> + +void InitSliceStruct(SliceStruct *slice, TensorC *in_tensor, TensorC *begin_tensor, TensorC *size_tensor) { + slice->param_length_ = in_tensor->shape_size_; + + int32_t *begin = (int32_t *)begin_tensor->data_; + int32_t *size = (int32_t *)size_tensor->data_; + + for (int i = 0; i < slice->param_length_; ++i) { + slice->shape_[i] = in_tensor->shape_[i]; + slice->begin_[i] = begin[i]; + slice->size_[i] = size[i] < 0 ? slice->shape_[i] - slice->begin_[i] : size[i]; + slice->end_[i] = slice->begin_[i] + slice->size_[i]; + } + return; +} + +void PadSliceParameterTo8D(SliceStruct *param) { + int32_t begin[DIMENSION_8D]; + int32_t end[DIMENSION_8D]; + int32_t slice_size[DIMENSION_8D]; + int32_t data_shape[DIMENSION_8D]; + for (int32_t i = 0; i < param->param_length_; ++i) { + begin[i] = param->begin_[i]; + end[i] = param->end_[i]; + slice_size[i] = param->size_[i] < 0 ? 
param->shape_[i] - begin[i] : param->size_[i]; + data_shape[i] = param->shape_[i]; + } + int32_t real_index = param->param_length_ - 1; + for (int32_t i = DIMENSION_8D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begin_[i] = begin[real_index]; + param->end_[i] = end[real_index]; + param->size_[i] = slice_size[real_index]; + param->shape_[i] = data_shape[real_index--]; + } else { + param->begin_[i] = 0; + param->end_[i] = 1; + param->size_[i] = 1; + param->shape_[i] = 1; + } + } + param->param_length_ = DIMENSION_8D; +} + +void DoSlice(const void *input, void *output, const SliceStruct *param, int thread_id, int thread_num, int data_size) { + int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + + int out_stride[8]; + out_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + out_stride[i] = out_stride[i + 1] * param->size_[i + 1]; + } + int count_per_thread = UP_DIV(param->size_[5], thread_num); + int thread_begin = thread_id * count_per_thread; + int thread_end = MSMIN(param->size_[5], thread_begin + count_per_thread); + int copy_size = param->size_[7] * data_size; + int in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + + for (int ii = 0; ii < param->size_[0]; ++ii) { + int out_offset0 = ii * out_stride[0]; + int in_offset0 = (ii + param->begin_[0]) * in_stride[0] + param->begin_[7]; + for (int jj = 0; jj < param->size_[1]; ++jj) { + int out_offset1 = jj * out_stride[1] + out_offset0; + int in_offset1 = (jj + param->begin_[1]) * in_stride[1] + in_offset0; + for (int kk = 0; kk < param->size_[2]; ++kk) { + int out_offset2 = kk * out_stride[2] + out_offset1; + int in_offset2 = (kk + param->begin_[2]) * in_stride[2] + in_offset1; + for (int ll = 0; ll < param->size_[3]; ++ll) { + int out_offset3 = ll * out_stride[3] + out_offset2; + int in_offset3 = (ll + param->begin_[3]) * in_stride[3] + in_offset2; + for (int i = 0; i < param->size_[4]; ++i) { + int out_offset4 = i * out_stride[4] + out_offset3; + int in_offset4 = (i + param->begin_[4]) * in_stride[4] + in_offset3; + for (int j = thread_begin; j < thread_end; ++j) { + int out_offset5 = j * out_stride[5] + out_offset4; + int in_offset5 = (j + param->begin_[5]) * in_stride[5] + in_offset4; + for (int k = 0; k < param->size_[6]; ++k) { + int out_offset6 = k * out_stride[6] + out_offset5; + int in_offset6 = (k + param->begin_[6]) * in_stride[6] + in_offset5; + memcpy(int8_out + out_offset6 * data_size, int8_in + in_offset6 * data_size, copy_size); + } + } + } + } + } + } + } +} + +static bool WhetherCopyByAxis(const int32_t *begin, const int32_t *end, const int32_t *shape, int dim) { + for (int i = dim + 1; i < DIMENSION_8D; ++i) { + if (begin[i] != 0 || end[i] != shape[i]) return false; + } + return true; +} + +void DoSliceNoParallel(const void *input, void *output, const SliceStruct *param, int data_size) { + int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + + int copy_size = param->size_[7] * data_size; + int in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + bool axis_copy_flag[DIMENSION_8D] = {false}; + for (int i = 0; i < DIMENSION_8D; ++i) { + axis_copy_flag[i] = WhetherCopyByAxis(param->begin_, param->end_, param->shape_, i); + } + int out_offset = 0; + for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) { + int in_offset0 = dim0 * in_stride[0] + param->begin_[7]; +#define FAST_COPY_IF_NEED(rank) \ + if 
(axis_copy_flag[rank]) { \ + int left_block_num = param->end_[rank] - dim##rank; \ + memcpy(int8_out + out_offset * data_size, int8_in + in_offset##rank * data_size, \ + in_stride[rank] * left_block_num * data_size); \ + out_offset += in_stride[rank] * left_block_num; \ + dim##rank += left_block_num; \ + continue; \ + } + FAST_COPY_IF_NEED(0); + for (int dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) { + int in_offset1 = dim1 * in_stride[1] + in_offset0; + FAST_COPY_IF_NEED(1); + for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) { + int in_offset2 = in_offset1 + dim2 * in_stride[2]; + FAST_COPY_IF_NEED(2); + for (int32_t dim3 = param->begin_[3]; dim3 < param->end_[3]; ++dim3) { + int in_offset3 = in_offset2 + dim3 * in_stride[3]; + FAST_COPY_IF_NEED(3); + for (int32_t dim4 = param->begin_[4]; dim4 < param->end_[4]; ++dim4) { + int in_offset4 = in_offset3 + dim4 * in_stride[4]; + FAST_COPY_IF_NEED(4); + for (int32_t dim5 = param->begin_[5]; dim5 < param->end_[5]; ++dim5) { + int in_offset5 = in_offset4 + dim5 * in_stride[5]; + FAST_COPY_IF_NEED(5); +#undef FAST_COPY_IF_NEED + for (int32_t dim6 = param->begin_[6]; dim6 < param->end_[6]; ++dim6) { + int in_offset6 = in_offset5 + dim6 * in_stride[6]; + memcpy(int8_out + out_offset * data_size, int8_in + in_offset6 * data_size, copy_size); + out_offset += param->size_[7]; + } + } + } + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h new file mode 100644 index 0000000000000000000000000000000000000000..8656cb824b4fde11aa3ddeb098b8af28e192c387 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_SLICE_BASE_H_ +#define NNACL_BASE_SLICE_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/kernel/slice.h" + +#ifdef __cplusplus +extern "C" { +#endif +void InitSliceStruct(SliceStruct *slice, TensorC *in_tensor, TensorC *begin_tensor, TensorC *size_tensor); +void PadSliceParameterTo8D(SliceStruct *param); + +void DoSlice(const void *input, void *output, const SliceStruct *param, int thread_id, int thread_num, int data_size); +void DoSliceNoParallel(const void *input, void *output, const SliceStruct *param, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_SLICE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c new file mode 100644 index 0000000000000000000000000000000000000000..9633a35d87f98bd409579c514138b4cc2ee409fe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c @@ -0,0 +1,54 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/space_to_depth_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, const int *out_shape, int shape_size, + SpaceToDepthParameter *param, int task_id) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int output_h = out_shape[kNHWC_H]; + int unit_per_thread = UP_DIV(output_h, param->op_parameter_.thread_num_); + int h_start = unit_per_thread * task_id; + int h_end = MSMIN(h_start + unit_per_thread, output_h); + + int block_size = param->block_size_; + int in_strides[C4NUM]; + int out_strides[C4NUM]; + ComputeStrides(in_shape, in_strides, shape_size); + ComputeStrides(out_shape, out_strides, shape_size); + for (int i = 0; i < out_shape[0]; ++i) { + int64_t in_offset_n = i * in_strides[0]; + int64_t out_offset_n = i * out_strides[0]; + for (int j = h_start; j < h_end; ++j) { + int64_t in_offset_h = in_offset_n + j * block_size * in_strides[1]; + int64_t out_offset_h = out_offset_n + j * out_strides[1]; + for (int k = 0; k < out_shape[2]; ++k) { + int64_t in_offset_w = in_offset_h + k * block_size * in_strides[2]; + int64_t out_offset_w = out_offset_h + k * out_strides[2]; + for (int l = 0; l < block_size; ++l) { + memcpy((int8_t *)output + (out_offset_w + l * block_size * in_strides[DIMENSION_2D]) * param->date_type_len, + (const int8_t *)input + (in_offset_w + l * in_strides[DIMENSION_1D]) * param->date_type_len, + block_size * in_strides[DIMENSION_2D] * param->date_type_len); + } + } + } + } + return NNACL_OK; +}
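[Editor's note, not part of the patch: SpaceToDepthForNHWC moves each block_size x block_size spatial tile into the channel dimension, so an NHWC input of 1x4x4xC with block_size = 2 becomes 1x2x2x(4C); the shapes and names below are illustrative.]

  // in_shape = {1, 4, 4, C}, out_shape = {1, 2, 2, 4 * C}, single-threaded call
  int ret = SpaceToDepthForNHWC(in_data, out_data, in_shape, out_shape, 4, param, 0);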
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h new file mode 100644 index 0000000000000000000000000000000000000000..916f3c18f473cdd8dd0fcc35cd148b3edb50def5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_SPACE_TO_DEPTH_BASE_H_ +#define NNACL_BASE_SPACE_TO_DEPTH_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/space_to_depth_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, const int *out_shape, int shape_size, + SpaceToDepthParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_SPACE_TO_DEPTH_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c new file mode 100644 index 0000000000000000000000000000000000000000..f48922c3d7b962c3ea9c1df0fdbcfbb308adb4c4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/split_base.h" +#include "nnacl_c/split_parameter.h" +#include +#include "nnacl_c/errorcode.h" + +int DoSplit(const void *in_data, void **out_data, const int *input_shape, int offset, int num_unit, + const SplitParameter *split_param, int data_size) { + const int8_t *int8_in = (int8_t *)in_data; + + const int num_split = split_param->num_split_; + const int *split_sizes = split_param->split_sizes_; + const int *strides = split_param->strides_; + const int split_dim = split_param->split_dim_; + int in_stride = strides[split_dim]; + + int in_stride_bytes = in_stride * data_size; + + int split_which; + int split_times; + int stride_per_split = in_stride * input_shape[split_dim]; + + split_which = offset % num_split; + split_times = offset / num_split; + const int8_t *src = int8_in + split_times * stride_per_split * data_size; + + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride * data_size; + } + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int split_size = split_sizes[split_which]; + int8_t *int8_out = (int8_t *)out_data[split_which]; + int8_t *dst = int8_out + split_times * in_stride * split_size * data_size; + (void)memcpy(dst, src, split_size * in_stride_bytes); + src += split_size * in_stride * data_size; + } + + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h new file mode 100644 index 0000000000000000000000000000000000000000..71a5af467c5c8ebfaf7e3d4e4d026a9e2d8a6b1f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_SPLIT_BASE_H_ +#define NNACL_BASE_SPLIT_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoSplit(const void *in_data, void **out_data, const int *input_shape, int offset, int num_unit, + const SplitParameter *split_param, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_SPLIT_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c new file mode 100644 index 0000000000000000000000000000000000000000..dfca91702fe18751ebca30370d93a949ec7c8ab4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c new file mode 100644 index 0000000000000000000000000000000000000000..dfca91702fe18751ebca30370d93a949ec7c8ab4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/split_with_over_lap_base.h" +#include <string.h> +#include "nnacl_c/errorcode.h" + +int DoSplitWithOverlapParallel(const char *in_data, char **out_data, int slice_idx, + const SplitWithOverlapParameter *param, const int *start_indices, + const int *end_indices) { + int start_index = start_indices[slice_idx]; + int end_index = end_indices[slice_idx]; + + int input_stride = param->split_dim_size_ * param->inner_stride_ * param->element_bytes_; + int out_stride = (end_index - start_index) * param->inner_stride_ * param->element_bytes_; + + const char *src_ptr = in_data + start_index * param->inner_stride_ * param->element_bytes_; + char *dst_ptr = out_data[slice_idx]; + + for (int i = 0; i < param->outer_total_dim_; i++) { + (void)memcpy(dst_ptr + i * out_stride, src_ptr, out_stride); + src_ptr += input_stride; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h new file mode 100644 index 0000000000000000000000000000000000000000..3b5db1c1516371d62b002a379c32fa03471fa133 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ +#define NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoSplitWithOverlapParallel(const char *in_data, char **out_data, int slice_idx, + const SplitWithOverlapParameter *param, const int *start_indices, + const int *end_indices); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c new file mode 100644 index 0000000000000000000000000000000000000000..64454dd5d532ee4e1d4604f6382331d7c89a745f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/base/stack_base.h" + +void Stack(void **inputs, void *output, size_t input_num, size_t copy_size, int outer_start, int outer_end) { + size_t out_offset = 0; + for (int i = outer_start; i < outer_end; ++i) { + for (size_t j = 0; j < input_num; ++j) { + memcpy((char *)output + out_offset, (char *)inputs[j] + i * copy_size, copy_size); + out_offset += copy_size; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b54a9e75a88d5922b219af1ec984cc40521010b0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_STACK_BASE_H_ +#define NNACL_BASE_STACK_BASE_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/stack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Stack(void **inputs, void *output, size_t input_num, size_t copy_size, int outer_start, int outer_end); +#ifdef __cplusplus +} +#endif +#endif // NNACL_BASE_STACK_BASE_H_
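[Editor's note, not part of the patch: Stack() writes one copy_size chunk from each input per outer step, which stacks N equal tensors along a middle axis; the outer range lets callers shard the loop across threads. The names below are illustrative.]

  // stack N tensors of shape (outer, inner) along axis 1 in one call
  Stack(input_ptrs, output_ptr, N, inner * sizeof(float), 0, outer);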
+ */ + +#include "nnacl_c/base/tile_base.h" +#include +#include "nnacl_c/errorcode.h" + +void DoCopyData(const uint8_t *input_data, uint8_t *output_data, size_t size, size_t data_size, size_t multiple) { + uint8_t *out_data = output_data; + for (size_t i = 0; i < multiple; ++i) { + (void)memcpy(out_data, input_data, size * sizeof(uint8_t) * data_size); + out_data += size * data_size; + } +} + +int DoTileOneDimension(uint8_t *input_data, uint8_t *output_data, size_t dim, const TileStruct *tile) { + int src_dim_size = tile->in_shape_[dim]; + if (dim == tile->in_dim_ - 1) { + DoCopyData(input_data, output_data, src_dim_size, tile->data_size_, tile->multiples_[dim]); + return NNACL_OK; + } + for (int i = 0; i < src_dim_size; ++i) { + for (int j = 0; j < tile->multiples_[dim]; ++j) { + int in_pos = tile->in_strides_[dim] * i; + int out_pos = tile->out_strides_[dim] * (i + j * src_dim_size); + DoTileOneDimension(input_data + in_pos * tile->data_size_, output_data + out_pos * tile->data_size_, dim + 1, + tile); + } + } + return NNACL_OK; +} + +void Tile(void *input_data, void *output_data, const TileStruct *tile) { + DoTileOneDimension((uint8_t *)input_data, (uint8_t *)output_data, 0, tile); +} + +void TileSimple(void *input_data, void *output_data, size_t begin, size_t end, const TileStruct *tile) { + uint8_t *out_data = output_data; + uint8_t *in_data = input_data; + size_t dst_one_row_size = tile->fast_stride_ * tile->fast_multiple_ * tile->data_size_; + for (size_t i = begin; i < end; ++i) { + uint8_t *src = in_data + i * tile->fast_stride_ * tile->data_size_; + uint8_t *dst = out_data + i * tile->fast_stride_ * tile->fast_multiple_ * tile->data_size_; + size_t offset = tile->fast_stride_ * tile->data_size_; + (void)memcpy(dst, src, offset); + // copy size double each time + while (2 * offset <= dst_one_row_size) { + (void)memcpy(dst + offset, dst, offset); + offset *= 2; + } + if (2 * offset > dst_one_row_size) { + (void)memcpy(dst + offset, dst, dst_one_row_size - offset); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h new file mode 100644 index 0000000000000000000000000000000000000000..cc84e8d81a2c7c9a0de81df0a9e45b40d0348b62 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_BASE_TILE_BASE_H_ +#define NNACL_BASE_TILE_BASE_H_ + +#include "nnacl_c/kernel/tile.h" +#include "nnacl_c/tile_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Tile(void *input_data, void *output_data, const TileStruct *tile); +void TileSimple(void *input_data, void *output_data, size_t begin, size_t end, const TileStruct *tile); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_TILE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c new file mode 100644 index 0000000000000000000000000000000000000000..55a53d581215be126d7ed77ad4aa7d8ba55f460c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c @@ -0,0 +1,274 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/transpose_base.h" +#include "nnacl_c/errorcode.h" + +#define TRANSPOSE_TWO_DIMS(TYPE, NAME) \ + void TransposeDim2##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * output1; \ + int stride0_i = i * 1 * stride0; \ + for (int j = 0; j < output1; ++j) { \ + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; \ + } \ + } \ + } + +#define TRANSPOSE_THREE_DIMS(TYPE, NAME) \ + void TransposeDim3##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; \ + } \ + } \ + } \ + } + +#define TRANSPOSE_FOUR_DIMS(TYPE, NAME) \ + void TransposeDim4##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int 
output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = \ + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_FIVE_DIMS(TYPE, NAME) \ + void TransposeDim5##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int stride4 = strides[perm[4]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int out_stride3 = out_strides[3]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + const int output4 = output_shape[4]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + int out_stride3_m = m * out_stride3; \ + int stride3_m = m * stride3; \ + for (int n = 0; n < output4; ++n) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = \ + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; \ + } \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_SIX_DIMS(TYPE, NAME) \ + void TransposeDim6##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int stride4 = strides[perm[4]]; \ + const int stride5 = strides[perm[5]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int out_stride3 = out_strides[3]; \ + const int out_stride4 = out_strides[4]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + const int output4 = output_shape[4]; \ + const int output5 = output_shape[5]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + int out_stride3_m = m * out_stride3; \ + int stride3_m = m * stride3; \ + for 
(int n = 0; n < output4; ++n) { \ + int out_stride4_n = n * out_stride4; \ + int stride4_n = n * stride4; \ + for (int g = 0; g < output5; ++g) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = \ + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; \ + } \ + } \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_DIMS(TYPE, NAME) \ + void TransposeDims##NAME(const TYPE *in_data, TYPE *out_data, const int *output_shape, \ + const TransposeParameter *transpose_param, int task_id, int thread_num) { \ + NNACL_CHECK_NULL_RETURN_VOID(in_data); \ + NNACL_CHECK_NULL_RETURN_VOID(out_data); \ + NNACL_CHECK_NULL_RETURN_VOID(output_shape); \ + NNACL_CHECK_NULL_RETURN_VOID(transpose_param); \ + NNACL_CHECK_ZERO_RETURN(thread_num); \ + const int *perm = transpose_param->perm_; \ + const int *strides = transpose_param->strides_; \ + const int *out_strides = transpose_param->out_strides_; \ + int num_axes = transpose_param->num_axes_; \ + size_t data_size = (*out_strides) * output_shape[0]; \ + size_t offset_size = UP_DIV(data_size, thread_num); \ + size_t task_offset = offset_size * task_id; \ + int count = data_size - task_offset; \ + if (count <= 0) { \ + return; \ + } \ + count = MSMIN(offset_size, count); \ + for (int idx = task_offset; idx < task_offset + count; ++idx) { \ + int pos = idx; \ + int output_idx = 0; \ + int input_idx = 0; \ + for (int i = 0; i < num_axes; ++i) { \ + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); \ + int position = pos / *(out_strides + i); \ + int out_stride = i < num_axes - 1 ? out_strides[i] : 1; \ + output_idx += (position * out_stride); \ + input_idx += (position * strides[perm[i]]); \ + pos -= position * (*(out_strides + i)); \ + } \ + out_data[output_idx] = in_data[input_idx]; \ + } \ + } + +#define DOTRANSPOSE(TYPE, NAME) \ + int DoTranspose##NAME(const TYPE *in_data, TYPE *out_data, const int *output_shape, \ + const TransposeParameter *transpose_param) { \ + NNACL_CHECK_NULL_RETURN_ERR(in_data); \ + NNACL_CHECK_NULL_RETURN_ERR(out_data); \ + NNACL_CHECK_NULL_RETURN_ERR(output_shape); \ + NNACL_CHECK_NULL_RETURN_ERR(transpose_param); \ + const int *perm = transpose_param->perm_; \ + const int *strides = transpose_param->strides_; \ + const int *out_strides = transpose_param->out_strides_; \ + int data_size = transpose_param->data_num_ * sizeof(TYPE); \ + int num_axes = transpose_param->num_axes_; \ + bool needTranspose = false; \ + for (int i = 1; i < num_axes; ++i) { \ + if (perm[i] - perm[i - 1] != 1) { \ + needTranspose = true; \ + break; \ + } \ + } \ + if (!needTranspose) { \ + (void)memcpy(out_data, in_data, data_size); \ + return NNACL_OK; \ + } \ + for (int i = 0; i < num_axes; ++i) { \ + if (perm[i] < 0) { \ + return NNACL_PARAM_INVALID; \ + } \ + } \ + if (num_axes == 2) { \ + TransposeDim2##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 3) { \ + TransposeDim3##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 4) { \ + TransposeDim4##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 5) { \ + TransposeDim5##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 6) { \ + TransposeDim6##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else { \ + return NNACL_ERR; \ + } \ + return NNACL_OK; \ + } + +#define TRANSPOSE_TEMPLATE(TYPE, NAME) \ + TRANSPOSE_TWO_DIMS(TYPE, NAME) \ + 
TRANSPOSE_THREE_DIMS(TYPE, NAME) \ + TRANSPOSE_FOUR_DIMS(TYPE, NAME) \ + TRANSPOSE_FIVE_DIMS(TYPE, NAME) \ + TRANSPOSE_SIX_DIMS(TYPE, NAME) \ + TRANSPOSE_DIMS(TYPE, NAME) \ + DOTRANSPOSE(TYPE, NAME) + +TRANSPOSE_TEMPLATE(uint8_t, UInt8) +TRANSPOSE_TEMPLATE(uint16_t, UInt16) +TRANSPOSE_TEMPLATE(uint32_t, UInt32) +TRANSPOSE_TEMPLATE(uint64_t, UInt64) +TRANSPOSE_TEMPLATE(int16_t, Int16) +TRANSPOSE_TEMPLATE(int32_t, Int32) +TRANSPOSE_TEMPLATE(int64_t, Int64) +TRANSPOSE_TEMPLATE(double, Float64) +TRANSPOSE_TEMPLATE(bool, Bool) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h new file mode 100644 index 0000000000000000000000000000000000000000..67fa636f40325100d635e1fcbf96df9b7a0a1d2e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_TRANSPOSE_BASE_H_ +#define NNACL_BASE_TRANSPOSE_BASE_H_ + +#include "nnacl_c/transpose_parameter.h" +#include <stdbool.h> + +#ifdef __cplusplus +extern "C" { +#endif + +int DoTransposeUInt8(const uint8_t *in_data, uint8_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeUInt16(const uint16_t *in_data, uint16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeUInt32(const uint32_t *in_data, uint32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeUInt64(const uint64_t *in_data, uint64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeInt16(const int16_t *in_data, int16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeInt32(const int32_t *in_data, int32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeInt64(const int64_t *in_data, int64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeFloat64(const double *in_data, double *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeBool(const bool *in_data, bool *out_data, const int *output_shape, + const TransposeParameter *transpose_param); + +void TransposeDimsUInt8(const uint8_t *in_data, uint8_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsUInt16(const uint16_t *in_data, uint16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsUInt32(const uint32_t *in_data, uint32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsUInt64(const uint64_t *in_data, uint64_t *out_data, const int *output_shape, + const TransposeParameter 
*transpose_param, int task_id, int thread_num); +void TransposeDimsInt16(const int16_t *in_data, int16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsInt32(const int32_t *in_data, int32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsInt64(const int64_t *in_data, int64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsFloat64(const double *in_data, double *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsBool(const bool *in_data, bool *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_TRANSPOSE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c new file mode 100644 index 0000000000000000000000000000000000000000..d962914c11aa2740e6269275c895acab5bb60bbe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/base/unsorted_segment_sum_base.h" +#include "nnacl_c/errorcode.h" + +#define UNSORTEDSEGMENTSUM(type, type1) \ + int UnsortedSegmentSum_##type##_##type1(const type *input, int unit_num, int input_dim1, const type1 *indices, \ + type *output, int output_dim0, int output_dim1) { \ + NNACL_CHECK_NULL_RETURN_ERR(input); \ + NNACL_CHECK_NULL_RETURN_ERR(indices); \ + NNACL_CHECK_NULL_RETURN_ERR(output); \ + if (input_dim1 == 0) { \ + return NNACL_ERR; \ + } \ + for (int i = 0; i < unit_num; ++i) { \ + int j = i / input_dim1; \ + int k = i % input_dim1; \ + \ + type1 index = indices[j]; \ + if (index < 0 || index >= output_dim0) { \ + continue; \ + } \ + type1 output_index = index * output_dim1 + k; \ + output[output_index] += input[i]; \ + } \ + return NNACL_OK; \ + } + +UNSORTEDSEGMENTSUM(int, int) +UNSORTEDSEGMENTSUM(float, int) +UNSORTEDSEGMENTSUM(int, int64_t) +UNSORTEDSEGMENTSUM(float, int64_t) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h new file mode 100644 index 0000000000000000000000000000000000000000..e05d62f89215d09b20dfd65f81c44945f77717d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_
+#define NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#define UnsortedSegmentSum(type, type1, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \
+  UnsortedSegmentSum_##type##_##type1(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1)
+int UnsortedSegmentSum_int_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output,
+                               int output_dim0, int output_dim1);
+int UnsortedSegmentSum_float_int(const float *input, int unit_num, int input_dim1, const int *indices, float *output,
+                                 int output_dim0, int output_dim1);
+int UnsortedSegmentSum_int_int64_t(const int *input, int unit_num, int input_dim1, const int64_t *indices, int *output,
+                                   int output_dim0, int output_dim1);
+int UnsortedSegmentSum_float_int64_t(const float *input, int unit_num, int input_dim1, const int64_t *indices,
+                                     float *output, int output_dim0, int output_dim1);
+#ifdef __cplusplus
+}
+#endif
+#endif // NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_
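The UnsortedSegmentSum macro dispatches to the type-suffixed implementations above by token pasting. A minimal sketch of a call site, with shapes and values invented purely for illustration; note the output buffer must be zero-initialized, since the kernel only accumulates into it:

#include <stdio.h>
#include "nnacl_c/base/unsorted_segment_sum_base.h"
#include "nnacl_c/errorcode.h"

int main(void) {
  // A [4, 2] input: unit_num = 8 elements in total, input_dim1 = 2 per row.
  const float input[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  const int indices[4] = {0, 2, 0, 1};  // segment id for each of the 4 rows
  float output[6] = {0};                // 3 segments x 2 columns, pre-zeroed
  // Expands to UnsortedSegmentSum_float_int(input, 8, 2, indices, output, 3, 2).
  int ret = UnsortedSegmentSum(float, int, input, 8, 2, indices, output, 3, 2);
  if (ret != NNACL_OK) {
    return ret;
  }
  printf("%.0f %.0f\n", output[0], output[1]);  // rows 0 and 2 summed: 6 8
  return 0;
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c
new file mode 100644
index 0000000000000000000000000000000000000000..d286de837d0d41b3103badf7c0a7f73e411f208c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.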
+ */
+
+#include "nnacl_c/base/unstack_base.h"
+
+void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size) {
+  NNACL_CHECK_NULL_RETURN_VOID(input);
+  NNACL_CHECK_NULL_RETURN_VOID(output);
+  NNACL_CHECK_NULL_RETURN_VOID(para);
+  const int8_t *in_addr = (int8_t *)input;
+  for (int j = 0; j < para->num_; j++) {
+    int8_t *out_addr = (int8_t *)output[j];
+    int out_offset = 0;
+    for (int i = 0; i < para->pre_dims_; i++) {
+      int in_offset = i * para->axis_dim_ * para->after_dims_ + j * para->after_dims_;
+      (void)memcpy(out_addr + out_offset * data_size, in_addr + in_offset * data_size, para->after_dims_ * data_size);
+      out_offset += para->after_dims_;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b0856de654a7d17cf5989cd65a0abd845c504ce
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_UNSTACK_BASE_H_
+#define NNACL_BASE_UNSTACK_BASE_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/unstack_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_BASE_UNSTACK_BASE_H_
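Unstack views the input as [pre_dims_, axis_dim_, after_dims_] and copies slice j of the middle axis into output[j]. A small illustrative driver; the field values below are invented for the example, and only the fields the kernel actually reads are set:

#include <stdio.h>
#include <string.h>
#include "nnacl_c/base/unstack_base.h"

int main(void) {
  // Unstack a [2, 3, 4] float tensor along axis 1 into 3 outputs of shape [2, 4].
  float in[24];
  for (int i = 0; i < 24; ++i) {
    in[i] = (float)i;
  }
  float out0[8], out1[8], out2[8];
  void *outs[3] = {out0, out1, out2};
  UnstackParameter para;
  memset(&para, 0, sizeof(para));  // Unstack ignores the remaining fields
  para.num_ = 3;                   // number of outputs, equal to axis_dim_
  para.pre_dims_ = 2;              // product of dims before the axis
  para.axis_dim_ = 3;              // size of the unstacked axis
  para.after_dims_ = 4;            // product of dims after the axis
  Unstack(in, outs, &para, (int)sizeof(float));
  printf("%.0f %.0f\n", out1[0], out1[4]);  // elements (0,1,0) and (1,1,0): 4 16
  return 0;
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..ada3dd0b9f83c7dfca8c24cce86b4cc75b10760e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.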
+ */
+#ifndef NNACL_BATCH_TO_SPACE_PARAMETER_H_
+#define NNACL_BATCH_TO_SPACE_PARAMETER_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+
+#define BATCH_TO_SPACE_BLOCK_SHAPE_SIZE 2
+
+typedef struct BatchToSpaceParameter {
+  OpParameter op_parameter_;
+  int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE];
+  int32_t crops_[COMM_SHAPE_SIZE];
+} BatchToSpaceParameter;
+
+#endif // NNACL_BATCH_TO_SPACE_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..b140e9537f603c239251cff5167e10c32756d3b5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BATCHNORM_PARAMETER_H_
+#define NNACL_BATCHNORM_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct BatchNormParameter {
+  OpParameter op_parameter_;
+  float epsilon_;
+  bool is_training_;
+  float momentum_;
+} BatchNormParameter;
+
+#endif // NNACL_BATCHNORM_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e934aadf593df286e85d46babf96def4f829c24
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_BROADCAST_TO_PARAMETER_H_ +#define NNACL_BROADCAST_TO_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct BroadcastToParameter { + OpParameter op_parameter_; + int shape_[MAX_SHAPE_SIZE]; + size_t shape_size_; +} BroadcastToParameter; + +typedef struct BroadcastShapeInfo { + int input_shape_[MAX_SHAPE_SIZE]; + int input_shape_size_; + int output_shape_[MAX_SHAPE_SIZE]; + int output_shape_size_; +} BroadcastShapeInfo; + +#endif // NNACL_BROADCAST_TO_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ea5e85e82fb4d6c130793aebd7c35a2315f8936d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CALL_PARAMETER_H_ +#define NNACL_CALL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct CallParameter { + OpParameter op_parameter_; + bool is_tail_call; +} CallParameter; + +#endif // NNACL_CALL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..8f3fbf64abf45c47a8b632e8ffe0f7a7d10fb3cb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CLIP_PARAMETER_H_ +#define NNACL_CLIP_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct ClipParameter { + OpParameter op_parameter_; + float min_val_; + float max_val_; +} ClipParameter; + +#endif // NNACL_CLIP_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c new file mode 100644 index 0000000000000000000000000000000000000000..cad7ea1914abf5de333922550cb5fdaf8c8fb2d0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/common_func.h"
+
+int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) {
+  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3;
+}
+
+int64_t OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2) {
+  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3];
+}
+
+int Offset4d(const int *shape, const int *dims) { return Offset(shape, dims[0], dims[1], dims[2], dims[3]); }
+
+int64_t Offset6d(const int *shape, const int *dims) {
+  return ((OffsetComm(shape, dims[0], dims[1], dims[2]) + dims[3]) * shape[4] + dims[4]) * shape[5];
+}
+
+int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
+
+int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }
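Offset linearizes a 4D coordinate against an NHWC-style shape, and the branchless MinInt8/MaxInt8 work because -(a < b) is all-ones exactly when a < b, turning the XOR mask on or off. A quick hypothetical check of both:

#include <assert.h>
#include "nnacl_c/common_func.h"

int main(void) {
  // Offset of (1, 0, 2, 3) in a [2, 4, 5, 6] tensor:
  // ((1 * 4 + 0) * 5 + 2) * 6 + 3 == 135.
  const int shape[4] = {2, 4, 5, 6};
  const int dims[4] = {1, 0, 2, 3};
  assert(Offset(shape, 1, 0, 2, 3) == 135);
  assert(Offset4d(shape, dims) == 135);

  // b ^ (a ^ b) == a, so the mask selects a when a < b and leaves b otherwise.
  assert(MinInt8(-3, 7) == -3);
  assert(MaxInt8(-3, 7) == 7);
  return 0;
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..7463a30f50eb0718d37047d93ef0f6f999c14958
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.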
+ */
+
+#ifndef MINDSPORE_NNACL_COMMON_FUNC_H_
+#define MINDSPORE_NNACL_COMMON_FUNC_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/nnacl_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int8_t MinInt8(int8_t a, int8_t b);
+int8_t MaxInt8(int8_t a, int8_t b);
+int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
+int64_t OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
+int Offset4d(const int *shape, const int *dims);
+int64_t Offset6d(const int *shape, const int *dims);
+
+static inline bool isAddOverflow(int32_t x, int32_t y) {
+  int32_t sum = x + y;
+  return (x > 0 && y > 0 && sum < 0) || (x < 0 && y < 0 && sum > 0);
+}
+
+static inline bool isMulOverflow(int32_t x, int32_t y) {
+  int32_t p = x * y;
+  return (x != 0) && (p / x != y);
+}
+
+static inline int GetStride(int *strides, const int *shape, int length) {
+  if (length <= 0) {
+    return 1;
+  }
+  int stride = 1;
+  for (int i = length - 1; i >= 0; --i) {
+    strides[i] = stride;
+    stride *= shape[i];
+  }
+  return stride;
+}
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* MINDSPORE_NNACL_COMMON_FUNC_H_ */
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..c902be3987b843cf9588a0494b11691b1d2fd945
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_CONCAT_PARAMETER_H_
+#define MINDSPORE_NNACL_CONCAT_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+
+typedef struct ConcatParameter {
+  OpParameter op_parameter_;
+  ConcatQuantArg quant_arg_;
+  int axis_;
+} ConcatParameter;
+
+#endif // MINDSPORE_NNACL_CONCAT_PARAMETER_H_
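GetStride fills row-major strides in place and returns the total element count, which is how callers turn multi-dimensional coordinates into flat indices. A small sketch with invented values:

#include <assert.h>
#include "nnacl_c/common_func.h"

int main(void) {
  const int shape[3] = {2, 3, 4};
  int strides[3];
  int total = GetStride(strides, shape, 3);
  assert(total == 24);  // 2 * 3 * 4 elements in all
  // Row-major strides are {12, 4, 1}: coordinate (i, j, k) flattens to
  // i * 12 + j * 4 + k.
  assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1);
  assert(1 * strides[0] + 2 * strides[1] + 3 * strides[2] == 23);
  return 0;
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..54174e88765bb19220ea071b712c4f9bdf1de54c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.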
+ */
+#ifndef NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_
+#define NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct ConstantOfShapeParameter {
+  OpParameter op_parameter_;
+  union value_ {
+    float f32_value_;
+    int32_t int32_value_;
+    bool bool_value_;
+  } value_;
+  int data_type_;
+  int element_size_;
+} ConstantOfShapeParameter;
+
+#endif // NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..edaa02e65f959bce674f649710a2cc622cefb5cf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_CONV3D_PARAMETER_H_
+#define NNACL_CONV3D_PARAMETER_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+
+typedef struct Conv3DParameter {
+  OpParameter op_parameter_;
+} Conv3DParameter;
+
+#endif // NNACL_CONV3D_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..a3c301a507169d726a6ceba24da5ad90631e14c3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h
@@ -0,0 +1,169 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef NNACL_CONV_PARAMETER_H_ +#define NNACL_CONV_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct ConvParameter { + OpParameter op_parameter_; + ConvQuantArg conv_quant_arg_; + + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int group_; + int tile_num_; /* # */ + int input_batch_; /* # */ + int input_h_; /* # */ + int input_w_; /* # */ + int input_channel_; + int output_batch_; /* # */ + int output_h_; /* # */ + int output_w_; /* # */ + int output_channel_; + int thread_num_; /* # */ + int input_unit_; /* # */ + int output_unit_; /* # */ + PadType pad_mode_; + ActType act_type_; + int channel_multiplie_; /* # */ + int output_padding_w_; /* # */ + int output_padding_h_; /* # */ + int out_format_; + + bool dynamic_shape_; +} ConvParameter; + +typedef struct ConvComputeParam { + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + + int in_n_; + int in_h_; + int in_w_; + int in_c_; + int out_n_; + int out_h_; + int out_w_; + int out_c_; + + int in_hw_; + int out_hw_; + int kernel_hw_; + int tile_num_; +} ConvComputeParam; + +typedef struct SlidingWindowParam { + int left_; + int right_; + int top_; + int bottom_; + int c_block_; + int block_channel_; + int ic_align_; + int out_step_; + int out_h_step_; + int out_c_step_; + int out_w_step_; + int out_block_step_; + int in_step_; + int in_h_step_; + int in_sh_step_; // stride H + int in_sw_step_; // stride W + int in_kh_step_; // kernel H + int in_kw_step_; // kernel W + int kernel_step_; +} SlidingWindowParam; + +typedef struct ConvDwCalcParam { + void *num_pixels_; + void *out_w_start_; + void *out_w_end_; + int first_calc_kw_; +} ConvDwCalcParam; + +#define OUPUT_UNIT 2 +#define DECONV_WINOGRAD_DEFAULT_UNIT 3 /* # */ +#define DECONV_WINOGRAD_DEFAULT_TILE 8 /* # */ +#define DECONV_WINOGRAD_BUFFER_COUNT 8 /* # */ +typedef struct DeConvWg { /* # */ + void *b_buffer_; + void *AT_; + void *BT_; + + int kh_; + int kw_; + + int k_; + int i_; + int o_; +} DeConvWg; + +typedef struct DeConvWgABuffer { /* # */ + bool buf_init_; + void *middle_buffer_; + void *dest_buffer_; +} DeConvWgABuffer; + +typedef struct DeConvComputeUnit { /* # */ + void *weight_; + void *tmp_buffer_; + int w_start_; + int h_start_; + int w_size_; + int h_size_; + bool use_winograd_; + DeConvWg winograd_; +} DeConvComputeUnit; + +typedef struct DeConvParam { /* # */ + DeConvComputeUnit *compute_units_; + int compute_size_; + DeConvWgABuffer a_buffer_[DECONV_WINOGRAD_BUFFER_COUNT]; + int input_plane_; + int output_plane_; + int kernel_plane_; + int ic_div_; + int oc_div_; + int ic_up_; + int oc_up_; + int thread_num_; + int in_tile_count_; + int in_tile_h_count_; + int in_tile_w_count_; + int out_tile_h_; + int out_tile_w_; +} DeConvParam; + +#endif // NNACL_CONV_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ebae6d016defaab1273feeea1351bfb9fd1b051a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CROP_PARAMETER_H_ +#define NNACL_CROP_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct CropParameter { + OpParameter op_parameter_; + int64_t axis_; + int offset_size_; + int64_t offset_[COMM_SHAPE_SIZE]; +} CropParameter; + +#endif // NNACL_CROP_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..1e58b8dc5cbbe7c7d129d67706b65804af4934df --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CUMSUM_PARAMETER_H_ +#define NNACL_CUMSUM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CumSumParameter { + OpParameter op_parameter_; + bool reverse_; + bool exclusive_; + int axis_; +} CumsumParameter; + +#endif // NNACL_CUMSUM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e3970e7663ae7b624591b7b5555e23ec14921c05 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_CUSTOM_GRU_PARAMETER_H_ +#define NNACL_CUSTOM_GRU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CustomGruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int num_step; + int batch_size; + int input_size; + int hidden_size; +} CustomGruParameter; + +#endif // NNACL_CUSTOM_GRU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..3a3034040bc67c6fe9c9b1dc3cddc71489bde6b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CustomIsInfParameter { + // Primitive parameter + OpParameter op_parameter_; +} CustomIsInfParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..81f68ef8fe9b341b315c186f93ded99cf193fdd2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CustomMaskedFillParameter { + // Primitive parameter + OpParameter op_parameter_; +} CustomMaskedFillParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0e56cdf39b7f5a1509cacb986ccf0b6ecfb90170 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CUSTOM_PARAMETER_H_ +#define NNACL_CUSTOM_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +#define MAX_STR_LEN 64 +#define MAX_ATTR_NUM 8 + +typedef struct CustomParameter { + OpParameter op_parameter_; + char type[MAX_STR_LEN]; + char attr_name[MAX_ATTR_NUM][MAX_STR_LEN]; + char *attr_data[MAX_ATTR_NUM]; + int attr_num; +} CustomParameter; +#endif // NNACL_CUSTOM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..571e22dbe507d35ede52a66cf59e55f916413727 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#define NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct DepthToSpaceParameter { + OpParameter op_parameter_; + int32_t block_size_; + int32_t mode_; +} DepthToSpaceParameter; + +#endif // NNACL_DEPTH_TO_SPACE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..74828b375d70aeb937ca6a34d8b67e2f6bbb9e2d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ +#define NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct DetectionPostProcessParameter { + OpParameter op_parameter_; + float h_scale_; + float w_scale_; + float x_scale_; + float y_scale_; + float nms_iou_threshold_; + float nms_score_threshold_; + int64_t max_detections_; + int64_t detections_per_class_; + int64_t max_classes_per_detection_; + int64_t num_classes_; + bool use_regular_nms_; + bool out_quantized_; + + float *anchors_; + + void *decoded_boxes_; + void *nms_candidate_; + void *indexes_; + void *scores_; + void *all_class_indexes_; + void *all_class_scores_; + void *single_class_indexes_; + void *selected_; +} DetectionPostProcessParameter; + +#endif // NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..978a0365a93122858c8df44be33c9b554dd12081 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_DYNAMIC_QUANT_PARAMETER_H_
+#define NNACL_DYNAMIC_QUANT_PARAMETER_H_
+#include "nnacl_c/op_base.h"
+
+typedef struct DynamicQuantParameter {
+  OpParameter op_parameter_;
+  bool symmetric_;
+  int dst_type_;
+  int axis_num_;
+  int prefer_axes_[MAX_SHAPE_SIZE];
+} DynamicQuantParameter;
+
+#endif // NNACL_DYNAMIC_QUANT_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c
new file mode 100644
index 0000000000000000000000000000000000000000..7e1a284400f776de4ea528cafdb4b981d3f424c7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/errorcode.h"
+#include <stdbool.h>
+#include <string.h>
+
+// The message table is a 2D char array; take it as array-of-rows here, since
+// indexing it through a plain char ** would misread the row layout.
+void InitNNACLKernelErrorCode(char nnacl_kernel_error_msg[][MAX_MSG_LEN]) {
+  strcpy(nnacl_kernel_error_msg[NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID],
+         "In CropAndResize, the value of box idx should match: [0, batch).");
+  strcpy(nnacl_kernel_error_msg[NNACL_WHERE_INPUT_NUM_INVALID],
+         "Invalid input number. Where op input number supports 1 or 3.");
+  strcpy(nnacl_kernel_error_msg[NNACL_WHERE_CONDITION_DATA_TYPE_ERROR],
+         "Invalid input data type. Where op input data type supports int32, fp32 and bool.");
+  strcpy(nnacl_kernel_error_msg[NNACL_WHERE_CONDITION_NUM_INVALID],
+         "The length of three inputs are not equal to 1 or length of output, which is unacceptable.");
+  strcpy(nnacl_kernel_error_msg[NNACL_WHERE_INVALID_OUT_NUM], "The element number is invalid.");
+  strcpy(nnacl_kernel_error_msg[NNACL_WHERE_NUM_MAX_INVALID], "Inputs' lengths are zero.");
+  strcpy(nnacl_kernel_error_msg[NNACL_ERR], "NNACL common error.");
+}
+
+char *NNACLErrorMsg(int error_code) {
+  static char nnacl_kernel_error_msg[NNACL_COMMON_END][MAX_MSG_LEN];
+  static bool inited = false;
+  if (!inited) {
+    inited = true;
+    InitNNACLKernelErrorCode(nnacl_kernel_error_msg);
+  }
+
+  if (error_code > NNACL_OK && error_code < NNACL_COMMON_END) {
+    return nnacl_kernel_error_msg[error_code];
+  }
+
+  return "NNACL execute error!";
+}
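With the table built this way, a caller maps a status code straight to its message. A hypothetical call site, not part of the patch:

#include <stdio.h>
#include "nnacl_c/errorcode.h"

// Illustrative helper: log any non-OK NNACL status through the shared table.
static void ReportNNACLStatus(int ret) {
  if (ret != NNACL_OK) {
    printf("nnacl kernel failed (%d): %s\n", ret, NNACLErrorMsg(ret));
  }
}

int main(void) {
  ReportNNACLStatus(NNACL_WHERE_INPUT_NUM_INVALID);
  return 0;
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h
new file mode 100644
index 0000000000000000000000000000000000000000..a7c6190b2534bfb212c4cec997b009b8eb5c7039
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h
@@ -0,0 +1,208 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.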
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ERRORCODE_H_ +#define NNACL_ERRORCODE_H_ + +#include + +#define MAX_MSG_LEN 256 + +typedef enum ErrorCodeCommonEnum { + NNACL_OK = 0, + NNACL_ERR = 1, + NNACL_NULL_PTR, + NNACL_PARAM_INVALID, + NNACL_INFER_INVALID, + NNACL_INPUT_TENSOR_ERROR, + NNACL_OUTPUT_TENSOR_ERROR, + NNACL_INPUT_OUTPUT_DATA_TYPE_UNMATCH, + NNACL_FORMAT_ERROR, + NNACL_BUFFER_OVERFLOW, + NNACL_TENSOR_SIZE_INVALID, + NNACL_UNSUPPORTED_DATA_TYPE, + NNACL_UNSUPPORTED_FORMAT, + NNACL_MALLOC_BUFFER_FAILED, + NNACL_MALLOC_SIZE_INVALID, + NNACL_DISABLE_FP16, + NNACL_ADDN_SHAPE_UNMATCH, + NNACL_ACTIVATION_TYPE_INVALID, + NNACL_ARITHMETIC_DATA_TYPE_UNMATCH, + NNACL_ARITHMETIC_SHAPE_INVALID, + NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT, + NNACL_ARG_MIN_MAX_AXIS_INVALID, + NNACL_BIAS_ADD_SHAPE_NOT_MATCH, + NNACL_BIAS_ADD_SHAPE_OVERFLOW, + NNACL_BATCH_TO_SPACE_BLOCK_SHAPE_INVALID, + NNACL_BATCH_TO_SPACE_CROP_INVALID, + NNACL_BATCH_NORM_CHANNEL_SHAPE_INVALID, + NNACL_CLIP_DATA_TYPE_INVALID, + NNACL_CLIP_MINMAX_VALUE_INVALID, + NNACL_CONCAT_AXIS_INVALID, + NNACL_CONCAT_F16_INVALID_DATA_TYPE, + NNACL_CONCAT_F16_OUTPUT_DATA_INVALID, + NNACL_CONCAT_SHAPE_INVALID, + NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH, + NNACL_CONVOLUTION_INPUT_HW_OVERFLOW, + NNACL_CONVOLUTION_KERNEL_HW_OVERFLOW, + NNACL_CONVOLUTION_OUTPUT_HW_OVERFLOW, + NNACL_CONVOLUTION_WEIGHT_DATATYPE_INVALID, + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID, + NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT, + NNACL_CONVOLUTION_WEIGHT_DATA_INVALID, + NNACL_CONVOLUTION_BIAS_DATATYPE_INVALID, + NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID, + NNACL_DECONV_RESIZE_OC_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_SHAPE, + NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_REPACK, + NNACL_DECONVOLUTION_DEPTHWISE_CHANNEL_INVALID, + NNACL_DEPTH_TO_SPACE_INVALID_MODE, + NNACL_ELTWISE_INVALID_MOD, + NNACL_FILL_DATA_TYPE_INVALID, + NNACL_FUSED_BATCH_NORM_NO_CHANGE, + NNACL_FUSED_BATCH_DATA_TYPE_INVALID, + NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED, + NNACL_FUSED_BATCH_TRAIN_DATA_INVALID, + NNACL_FUSED_BATCH_TRAIN_PARAM_DATA_INVALID, + NNACL_GATHER_INDICES_DATA_TYPE_INVALID, + NNACL_GATHER_INDICES_VALUE_INVALID, + NNACL_GATHER_AXIS_INVALID, + NNACL_GATHER_INPUT_TENSOR_INVALID, + NNACL_GATHER_OUTPUT_TENSOR_INVALID, + NNACL_GATHER_D_AXIS_INVALID, + NNACL_GATHER_ND_COUNT_INVALID, + NNACL_GATHER_ND_INDICES_RANK_INVALID, + NNACL_GATHER_ND_INDICES_SHAPE_INVALID, + NNACL_GROUP_CONVOLUTION_GROUP_INVALID, + NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID, + NNACL_GROUP_CONVOLUTION_SHAPE_INVALID, + NNACL_GROUP_NORM_NUM_GROUPS_INVALID, + NNACL_GROUP_NORM_SHAPE_SIZE_INVALID, + NNACL_GROUP_NORM_FORMAT_INVALID, + NNACL_SOFTMAX_AXIS_INVALID, + NNACL_MATMUL_ACT_TYPE_INVALID, + NNACL_MATMUL_BIAS_INVALID, + NNACL_NON_ZERO_SHAPE_INVALID, + NNACL_NON_MAX_SUPPRESSION_TENSOR_SIZE_INVALID, + NNACL_NON_MAX_SUPPRESSION_PARAM_INVALID, + NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_INVALID, + NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_SCORE_UNMATCH, + NNACL_NON_MAX_SUPPRESSION_DIMENSION_SPATIAL_UNMATCH, + NNACL_NON_MAX_SUPPRESSION_UNSUPPORT_DEFINE_DATA, + NNACL_NON_MAX_SUPPRESSION_OUTPUT_SIZE_UNMATCH, + NNACL_ONE_HOT_AXIS_INVALID, + NNACL_ONE_HOT_OUTER_SIZE_INVALID, + NNACL_ONE_HOT_INNER_SIZE_INVALID, + NNACL_ONE_HOR_DEPTH_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_ON_VALUE_TENSOR_DATA_TYPE_INVALID, + 
NNACL_ONE_HOR_OFF_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_ON_OFF_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_PAD_SHAPE_INVALID, + NNACL_PAD_PADDING_VALID_INVALID, + NNACL_PAD_MIRROR_PAD_SIZE_INVALID, + NNACL_POW_INVALID_DATA_TYPE, + NNACL_PRELU_SLOPE_NUM_INVALID, + NNACL_PRIOR_BOX_VALUE_INVALID, + NNACL_PRIOR_BOX_RATIO_INVALID, + NNACL_LOCAL_RESPONSE_NORM_SHAPE_INVALID, + NNACL_LOCAL_RESPONSE_NORM_DEPTH_RADIUS_INVALID, + NNACL_LAYER_NORM_OUTPUT_NUM_INVALID, + NNACL_REDUCE_AXIS_SIZE_ERROR, + NNACL_REDUCE_AXES_TENSOR_ERROR, + NNACL_REDUCE_UNSUPPORTED_DATA_TYPE, + NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID, + NNACL_REDUCE_COEFF_DATA_TYPE_INVALID, + NNACL_REVERSE_AXIS_INVALID, + NNACL_REVERSE_AXIS_VALUE_INVALID, + NNACL_REVERSE_DATA_SIZE_INVALID, + NNACL_REVERSE_NUM_AXIS_INVALID, + NNACL_SCALE_AXIS_AND_SHAPE_UNMATCH, + NNACL_SCALE_UNSUPPORT_ACT_TYPE, + NNACL_SCALE_SCALE_SHAPE_UNMATCH, + NNACL_SCALE_INPUT_NUM_INVALID, + NNACL_STACK_TENSOR_SHAPE_INVALID, + NNACL_STRIDED_SLICE_INVALID_SHAPE_SIZE, + NNACL_STRIDED_SLICE_INVALID_DATA_SIZE, + NNACL_STRIDED_SLICE_UNSUPPORTED_DATA_TYPE, + NNACL_STRIDED_SLICE_INVALID_PARALLEL_MOD, + NNACL_STRIDED_SLICE_UNSUPPORTED_MAX_8D, + NNACL_SPLICE_SHAPE_INVALID, + NNACL_TILE_INPUT_SHAPE_INVALID, + NNACL_TILE_SECOND_INPUT_NUM_INVALID, + NNACL_TILE_SECOND_INPUT_VALUE_INVALID, + NNACL_TILE_SECOND_INPUT_DATA_TYPE_INVALID, + NNACL_TILE_RESIZE_IN_RUNTIME_FAILED, + NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR, + NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID, + NNACL_TRIU_INPUT_DIMS_INVALID, + NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE, + NNACL_TRANSPOSE_INPUT_TENSOR_NUM_INVALID, + NNACL_TRANSPOSE_INPUT_TENSOR_VALUD_INVALID, + NNACL_TRANSPOSE_PERM_DIMS_INVALID, + NNACL_TRANSPOSE_PERM_TENSOR_INVALID, + NNACL_TRANSPOSE_PERM_TENSOR_VALUE_INVALID, + NNACL_TRANSPOSE_PERM_DELETE_DIMENSION_FAILED, + NNACL_WHERE_INPUT_NUM_INVALID, + NNACL_WHERE_CONDITION_DATA_TYPE_ERROR, + NNACL_WHERE_CONDITION_NUM_INVALID, + NNACL_WHERE_INVALID_OUT_NUM, + NNACL_WHERE_NUM_MAX_INVALID, + NNACL_WHERE_BROAD_CAST_FAILED, + NNACL_COMMON_END +} ErrorCodeCommonEnum; + +typedef enum ErrorCodeFp32OpEnum { + NNACL_ERRCODE_OP_FP32_START = 10000, + NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, + NNACL_ERRCODE_REVERSE_MALLOC, + NNACL_ERRCODE_SQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_DIVISOR_ZERO, + NNACL_ERRCODE_INDEX_OUT_OF_RANGE, + NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR, + NNACL_ERRCODE_OP_FP32_END = 19999 +} ErrorCodeFp32OpEnum; + +typedef enum ErrorCodeFp16OpEnum { + NNACL_ERRCODE_OP_FP16_START = 20000, + NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR, + NNACL_ERRCODE_OP_FP16_END = 29999 +} ErrorCodeFp16OpEnum; + +typedef enum ErrorCodeUint8OpEnum { + NNACL_ERRCODE_OP_UINT8_START = 30000, + NNACL_ERRCODE_OP_UINT8_END = 39999 +} ErrorCodeUint8OpEnum; + +typedef enum ErrorCodeInt8OpEnum { + NNACL_ERRCODE_OP_INT8_START = 40000, + NNACL_ERRCODE_ADD_OVERFLOW, + NNACL_ERRCODE_MUL_OVERFLOW, + NNACL_ERRCODE_OP_INT8_END = 49999 +} ErrorCodeInt8OpEnums; + +#ifdef __cplusplus +extern "C" { +#endif +char *NNACLErrorMsg(int error_code); +#ifdef __cplusplus +} +#endif +#endif // NNACL_ERRORCODE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..1c1dd12e1470a78c32684def6a725120493ec9a9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 
2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_EXP_PARAMETER_H_ +#define NNACL_EXP_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ExpParameter { + OpParameter op_parameter_; + float base_; + float scale_; + float shift_; +} ExpParameter; + +#endif // NNACL_EXP_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a3c760c1e4d4398a76cf9ab33242df00ed647989 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c @@ -0,0 +1,533 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 
4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 
32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 
60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 
0(%[dst_9])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f59b52dbec70921c920dc2d38a67ece3fe766ca2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c @@ -0,0 +1,781 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * 
src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 4(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + 
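// each vbroadcastss splats one fp32 from a source row; src_6 with strides 0-2 and src_9 supply rows 6-9 of the 10-row tile +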
"vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 8(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 12(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 16(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps 
%%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 20(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 24(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + 
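// depth offset 28 = 7 * sizeof(float): each unrolled block consumes the next fp32 along the depth axis +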
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 28(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 32(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 36(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 40(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 44(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps 
%%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 48(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 52(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 
56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 56(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 60(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + 
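// zmm0-zmm19 form the 10x32 fp32 accumulator tile, two 16-lane columns per row +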
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 
0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c6b24cbac9afae95c1521ddffd1ad0c6ac0a7dd4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c @@ -0,0 +1,573 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + 
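// one FMA per row: acc[r] += weight column * broadcast src[r]; zmm0-zmm10 hold the 11x16 accumulator tile +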
"vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), 
%%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 
28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + 
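// block 10: one FMA per row folds depth element 10 into accumulators zmm0..zmm10
+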
"vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + 
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9452a5d3125cb25425d49dcd61527873b75ba165 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c @@ -0,0 +1,844 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_11x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                              const size_t act_flag, const size_t row_block, const size_t col_block,
+                                              const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                              const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  const float *dst_9 = dst + 9 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
+    "vmovups 0(%[dst_3]), %%zmm6\n"
+    "vmovups 64(%[dst_3]), %%zmm7\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 0(%[dst_6]), %%zmm12\n"
+    "vmovups 64(%[dst_6]), %%zmm13\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n"
+    "vmovups 0(%[dst_9]), %%zmm18\n"
+    "vmovups 64(%[dst_9]), %%zmm19\n"
+    "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n"
+    "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 64(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "vmovups 64(%[bias]), %%zmm9\n"
+    "vmovups 0(%[bias]), %%zmm10\n"
+    "vmovups 64(%[bias]), %%zmm11\n"
+    "vmovups 0(%[bias]), %%zmm12\n"
+    "vmovups 64(%[bias]), %%zmm13\n"
+    "vmovups 0(%[bias]), %%zmm14\n"
+    "vmovups 64(%[bias]), %%zmm15\n"
+    "vmovups 0(%[bias]), %%zmm16\n"
+    "vmovups 64(%[bias]), %%zmm17\n"
+    "vmovups 0(%[bias]), %%zmm18\n"
+    "vmovups 64(%[bias]), %%zmm19\n"
+    "vmovups 0(%[bias]), %%zmm20\n"
+    "vmovups 64(%[bias]), %%zmm21\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    "vxorps %%zmm15, %%zmm15,
%%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 
4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + 
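// rows 8..10 accumulate into pairs zmm16/17, zmm18/19, zmm20/21 (low/high 16 of the 32 columns)
+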
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 6 + "vmovups 768(%[weight]), 
%%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), 
%%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps 
%%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + 
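// block 12, second half: broadcast rows 6..10 and reuse the zmm31/zmm30 weight pair
+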
"vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" 
+ "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] 
"r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..14fa99e36bb636dab2a7b4f05b0e607c450ca7ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c @@ -0,0 +1,614 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 0(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, 
%%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + 
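// block 2: broadcast depth element 2 for rows 3..11, then run the 12 FMAs
+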
"vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + 
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" 
+ "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 11 + "vmovups 
704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + 
"vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm11, 0(%[dst_9], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6975229c5ab4c0e2fb4c8b088c2890f95d087b27 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c @@ -0,0 +1,908 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 0(%[bias]), %%zmm22\n" + "vmovups 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, 
%%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + 
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, 
%%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps 
%%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", 
"%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5b51eb3fcbf6b06768bb75a8e59040dedc21e53d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" 
+ // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..788591996f5df56f0497462b9d27ffbccd158460 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 
1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c2b38c19f6261d43e5e234629f2714a84d17a037 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), 
%%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + 
"vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a7c32c93add8a1b6ddaaeef5997c88fc19b1993e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c @@ -0,0 +1,278 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 15 + "vmovups 3840(%[weight]), 
%%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a6a2faebb0333d0b39f13c4378cb3d92b12559af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c @@ -0,0 +1,318 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + 
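// one broadcast feeds the five FMAs below: zmm0-zmm4 each hold a 16-float column of the 80-wide output tile +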
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" 
+ "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : 
"%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c9a2b59e9715d2b6818e2dc0a547264f323ac4db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c @@ -0,0 +1,358 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, 
%%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), 
%%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + 
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8db14458ca33249abecf59f2fa23d9c609217ccf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + 
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 
1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..667f8698296da1951dd12192b8ce5eb7c43d3e8b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c @@ -0,0 +1,261 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), 
%%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + 
"vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", 
"%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b494d3165698dbbfd6078ad88e4c0dfed9a04cfb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 
4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, 
%%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, 
%%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..99cce4e20a7c3649ee3f770fe245b96ff808aa78 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c @@ -0,0 +1,387 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" 
+ "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 
56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + 
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f396dce2ac3221947d072093fe7aeab117da7385 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + 
"vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + 
"vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + 
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + 
[inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7db94d09f13b3910c984d4ebe31a554134dc5292 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c @@ -0,0 +1,513 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps 
%%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, 
%%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, 
%%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d576a9868203cdc2aae6e485b9f0510c7645b558 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), 
%%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + 
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ffe63a35c2b12b437ea8b1536c3a0d15542a8724 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, 
%%zmm27, %%zmm5\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, 
%%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e94ec339d68935f22f35f35e10db99a49a3d8834 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c @@ -0,0 +1,410 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], 
%[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + 
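// Per depth step k this kernel reads three 64-byte weight strips (zmm31/zmm30/zmm29 at byte offsets + // k*192, k*192+64, k*192+128) and three broadcast src rows (zmm28/zmm27/zmm26), updating the nine + // accumulators zmm0..zmm8, one per (row, 16-float column strip) pair; informally, + //   dst[r * dst_stride + c] += src[r * src_stride + k] * weight[k * 48 + c]. +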
"vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, 
%%zmm8\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..976212df11de52e12fd79561862489f69ff51199 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c @@ -0,0 +1,496 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], 
%[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps 
%%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + 
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + 
"vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..42c46be60295bfd5f4dfcdebbd4142eaee9f7889 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c @@ -0,0 +1,583 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, 
%%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), 
%%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, 
%%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + 
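// each depth step of this 3x80 kernel loads five 16-float weight columns, then broadcasts one activation per output row for the 15 FMAs that follow +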
"vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" 
+ "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" 
+ "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..450dd072521ebbc19a2104ab5a54e00171c67ab0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c @@ -0,0 +1,669 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), 
%%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), 
%%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], 
%[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 
36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + 
"vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), 
%%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps 
%%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1c592474bb6740f4b7f9c1f7138500323ebdd89b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c @@ -0,0 +1,283 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
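
All of these generated kernels share one calling contract, so it is worth stating once in plain C. Each nnacl_gemm_avx512_<rows>x<cols>_kernel_nhwc_fp32 keeps a row_block x col_block output tile resident in zmm accumulators (col_block / 16 vectors per row: 18 registers for the 3x96 kernel above, 4 for the 4x16 kernel below), streams col_block packed weight floats per depth step, broadcasts one source float per row, and combines them with vfmadd231ps. The scalar sketch below is a reference reading of that contract; the flag encodings are inferred from the branches in the asm, not from a separate spec.

#include <stddef.h>

/* Scalar reference for one kernel call (illustrative sketch, not the shipped code).
 * Strides are in floats; the asm shifts them left by 2 to obtain byte strides. */
static void gemm_avx512_tile_reference(float *dst, const float *src, const float *weight,
                                       const float *bias, size_t act_flag, size_t row_block,
                                       size_t col_block, size_t depth, size_t src_stride,
                                       size_t dst_stride, size_t inc_flag) {
  for (size_t r = 0; r < row_block; ++r) {
    for (size_t c = 0; c < col_block; ++c) {
      /* inc_flag bit 0: resume accumulation from dst (an earlier depth slice);
       * otherwise start from bias[c] when bias is non-NULL, else from zero. */
      float acc = (inc_flag & 0x1) ? dst[r * dst_stride + c] : (bias != NULL ? bias[c] : 0.0f);
      for (size_t d = 0; d < depth; ++d) {
        /* weight is packed col_block floats per depth step. */
        acc += src[r * src_stride + d] * weight[d * col_block + c];
      }
      /* inc_flag bit 1 marks the final depth slice; only then does the
       * activation run (act_flag low bits select relu, then relu6). */
      if ((inc_flag & 0x2) && (act_flag & 0x3)) {
        acc = acc > 0.0f ? acc : 0.0f;
        if (act_flag & 0x1) acc = acc < 6.0f ? acc : 6.0f;
      }
      dst[r * dst_stride + c] = acc;
    }
  }
}

In the generated kernels, row_block and col_block are fixed at generation time, so those two parameters go unused; the variants differ only in how many zmm registers the tile occupies and in the literal offsets baked into the loads.
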
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 
1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c new file mode 100644 index 
0000000000000000000000000000000000000000..527bcd9f10723c220435ea7bf609fbd5b3ff23da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c @@ -0,0 +1,392 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" 
+ "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, 
%%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], 
%[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + 
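// relu: zmm31 was zeroed by the vxorps above, so each vmaxps computes max(acc, 0.0f) per lane
+ 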
"vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0443d2169703b23ba8f2ab607cd1e3decc3e26b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c @@ -0,0 +1,501 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + 
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + 
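+ // block 13, row 3: accumulate zmm9-zmm11 from the src_3 broadcast held in zmm25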
"vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, 
%%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a88f43292c9323e78146e750a0e5d173a0ef0213 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c @@ -0,0 +1,611 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" +
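+ // rows 1-3: reuse the four weight vectors (zmm31-zmm28) against the broadcasts zmm26, zmm25, zmm24; accumulators zmm4-zmm15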
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + 
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" 
+ "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..10d52e6fb5bce0f684a5952de6c57957cce6bdb8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c @@ -0,0 +1,720 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [dst_0]
"r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, 
%%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], 
%[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, 
%%zmm19\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps 
%%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, 
%%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 
0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + 
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..48f890f52f6608dadd6b992ecc9bfce4d25e68c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c @@ -0,0 +1,830 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 192(%[bias]), %%zmm21\n" + "vmovups 256(%[bias]), %%zmm22\n" + "vmovups 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0]
"r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), 
%%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, 
%%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + 
"vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, 
%%zmm11\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps 
%%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 
0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + 
"vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..097ffc2c3ebfe45d235962286a8dbfc20dcb023c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c @@ -0,0 +1,323 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + 
"vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, 
%%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4c0c43156b207f94f2ea6f75e12f234728376342 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c @@ -0,0 +1,455 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], 
%[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + 
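// note: each numbered block is one depth element; per block the weight offsets grow by 128 bytes (two zmm vectors) and the src broadcast offset by 4 bytes (one float) +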
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + 
"vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", 
"%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e5d0f28359677b9995b00f58b82d1ba2542a774c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c @@ -0,0 +1,588 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps 
%%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + 
"vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 
2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 
0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fc0a6bc40c7e4601c08fdbfe166f20b773cc4dd3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c @@ -0,0 +1,720 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2,
%%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, 
%%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + 
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), 
%%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + 
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + 
"vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9ec379130493a4c7d72b2896e22b19b1e5e71a20 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c @@ -0,0 +1,853 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "vmovups 256(%[bias]), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, 
%%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), 
%%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + 
"vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, 
%%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + 
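+      // note: zmm20-zmm24 accumulate the fifth output row (src_3 + src_stride)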
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), 
%%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, 
%%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..05f6f84991b95ebfec5aec21c5dd032cb4e365b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c @@ -0,0 +1,363 @@ +/** + * 
Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + 
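+      // note: the 6x16 tile needs only one weight vector per depth step, with six row broadcasts (zmm30-zmm25)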
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps 
%%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 
52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups 
%%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f59cc5f0a645064e48e3a28c8f24b7998c23af31 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c @@ -0,0 +1,518 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag) {
+ const float *dst_3 = dst + 3 * dst_stride;
+ size_t dst_stride_t = dst_stride << 2;  // stride in fp32 elements -> bytes
+ asm volatile(
+ // inc in depth: inc_flag bit 0 set means an earlier depth slice already wrote partial sums to dst, so reload them
+ "movq %[inc_flag], %%rax\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 64(%[dst_0]), %%zmm1\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
+ "vmovups 0(%[dst_3]), %%zmm6\n"
+ "vmovups 64(%[dst_3]), %%zmm7\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
+ "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
+ "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 64(%[bias]), %%zmm1\n"
+ "vmovups 0(%[bias]), %%zmm2\n"
+ "vmovups 64(%[bias]), %%zmm3\n"
+ "vmovups 0(%[bias]), %%zmm4\n"
+ "vmovups 64(%[bias]), %%zmm5\n"
+ "vmovups 0(%[bias]), %%zmm6\n"
+ "vmovups 64(%[bias]), %%zmm7\n"
+ "vmovups 0(%[bias]), %%zmm8\n"
+ "vmovups 64(%[bias]), %%zmm9\n"
+ "vmovups 0(%[bias]), %%zmm10\n"
+ "vmovups 64(%[bias]), %%zmm11\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+ "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+ "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+ "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+ "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+ "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+ "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+ "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+ "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+ ".align 16\n"
+ "2:\n"
+ :
+ : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3)
+ : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
+ const float *src_3 = src + 3 * src_stride;
+ size_t src_stride_t = src_stride << 2;
+ // 6 rows x 32 output channels: 12 accumulators (zmm0-zmm11), two 16-lane registers per row
+ asm volatile(
+ // main loop consumes 16 depth steps per iteration; loop 1 below handles the remainder one step at a time
+ "cmp $16, %[depth]\n"
+ "jb 1f\n"
+ ".align 16\n"
+ "0:\n"
+ // block 0
+ "vmovups 0(%[weight]), %%zmm31\n"
+ "vmovups 64(%[weight]), %%zmm30\n"
+ "vbroadcastss 0(%[src_0]), %%zmm29\n"
+ "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
+ "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
+ "vbroadcastss 0(%[src_3]), %%zmm26\n"
+ "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
+ "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n"
+ "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+ "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
+ "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+ "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
+ "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
+ "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
+ "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
+ "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
+ "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
+ "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+ "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+ // block 1
+ "vmovups 128(%[weight]), %%zmm31\n"
+ "vmovups 192(%[weight]), %%zmm30\n"
+ "vbroadcastss 4(%[src_0]), %%zmm29\n"
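+ // (each of blocks 1-15 repeats the block-0 pattern: load two 64-byte weight rows,
+ //  broadcast one fp32 from each of the six src rows, and issue twelve FMAs into
+ //  zmm0-zmm11; src advances 4 bytes and weight 128 bytes per block)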
+ "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, 
%%zmm11\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps 
%%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, 
%%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + 
"vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f083e9940bf3a8bbeb7a7c28b0df81fb735e395b --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c
@@ -0,0 +1,674 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_6x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag) {
+ const float *dst_3 = dst + 3 * dst_stride;
+ size_t dst_stride_t = dst_stride << 2;  // stride in fp32 elements -> bytes
+ // 6 rows x 48 output channels: 18 accumulators (zmm0-zmm17), three 16-lane registers per row
+ asm volatile(
+ // inc in depth: inc_flag bit 0 set means an earlier depth slice already wrote partial sums to dst, so reload them
+ "movq %[inc_flag], %%rax\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 64(%[dst_0]), %%zmm1\n"
+ "vmovups 128(%[dst_0]), %%zmm2\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+ "vmovups 0(%[dst_3]), %%zmm9\n"
+ "vmovups 64(%[dst_3]), %%zmm10\n"
+ "vmovups 128(%[dst_3]), %%zmm11\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n"
+ "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n"
+ "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n"
+ "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n"
+ "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 64(%[bias]), %%zmm1\n"
+ "vmovups 128(%[bias]), %%zmm2\n"
+ "vmovups 0(%[bias]), %%zmm3\n"
+ "vmovups 64(%[bias]), %%zmm4\n"
+ "vmovups 128(%[bias]), %%zmm5\n"
+ "vmovups 0(%[bias]), %%zmm6\n"
+ "vmovups 64(%[bias]), %%zmm7\n"
+ "vmovups 128(%[bias]), %%zmm8\n"
+ "vmovups 0(%[bias]), %%zmm9\n"
+ "vmovups 64(%[bias]), %%zmm10\n"
+ "vmovups 128(%[bias]), %%zmm11\n"
+ "vmovups 0(%[bias]), %%zmm12\n"
+ "vmovups 64(%[bias]), %%zmm13\n"
+ "vmovups 128(%[bias]), %%zmm14\n"
+ "vmovups 0(%[bias]), %%zmm15\n"
+ "vmovups 64(%[bias]), %%zmm16\n"
+ "vmovups 128(%[bias]), %%zmm17\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+ "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+ "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+ "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+ "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+ "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+ "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+ "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+ "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+ "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+ "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+ "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+ "vxorps %%zmm15, %%zmm15, %%zmm15\n"
+ "vxorps %%zmm16, %%zmm16,
%%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + 
"vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 
32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), 
%%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 14 + 
"vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + 
"vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..71d8d74d1dbf14055fa1bc447771582ef0e54741
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c
@@ -0,0 +1,830 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag) {
+ const float *dst_3 = dst + 3 * dst_stride;
+ size_t dst_stride_t = dst_stride << 2;  // stride in fp32 elements -> bytes
+ // 6 rows x 64 output channels: 24 accumulators (zmm0-zmm23), four 16-lane registers per row
+ asm volatile(
+ // inc in depth: inc_flag bit 0 set means an earlier depth slice already wrote partial sums to dst, so reload them
+ "movq %[inc_flag], %%rax\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 64(%[dst_0]), %%zmm1\n"
+ "vmovups 128(%[dst_0]), %%zmm2\n"
+ "vmovups 192(%[dst_0]), %%zmm3\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+ "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
+ "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
+ "vmovups 0(%[dst_3]), %%zmm12\n"
+ "vmovups 64(%[dst_3]), %%zmm13\n"
+ "vmovups 128(%[dst_3]), %%zmm14\n"
+ "vmovups 192(%[dst_3]), %%zmm15\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n"
+ "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n"
+ "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n"
+ "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n"
+ "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n"
+ "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n"
+ "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 64(%[bias]), %%zmm1\n"
+ "vmovups 128(%[bias]), %%zmm2\n"
+ "vmovups 192(%[bias]), %%zmm3\n"
+ "vmovups 0(%[bias]), %%zmm4\n"
+ "vmovups 64(%[bias]), %%zmm5\n"
+ "vmovups 128(%[bias]), %%zmm6\n"
+ "vmovups 192(%[bias]), %%zmm7\n"
+ "vmovups 0(%[bias]), %%zmm8\n"
+ "vmovups 64(%[bias]), %%zmm9\n"
+ "vmovups 128(%[bias]), %%zmm10\n"
+ "vmovups 192(%[bias]), %%zmm11\n"
+ "vmovups 0(%[bias]), %%zmm12\n"
+ "vmovups 64(%[bias]), %%zmm13\n"
+ "vmovups 128(%[bias]), %%zmm14\n"
+ "vmovups 192(%[bias]), %%zmm15\n"
+ "vmovups 0(%[bias]), %%zmm16\n"
+ "vmovups 64(%[bias]),
%%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), 
%%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, 
%%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, 
%%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), 
%%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, 
%%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + 
"vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, 
%%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a53f0dc4f5b4e756335631c6a3ffaa5ba152b139 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c @@ -0,0 +1,408 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30,
%%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + 
"vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" 
+ "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, 
%%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9d4a669467bf8a5632d4d76dd803307f3d2c3819 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c @@ -0,0 +1,587 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; +
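/* A scalar sketch of what the asm below computes (illustrative only, not nnacl API; strides are counted in floats, and 7/32 are this kernel's row_block/col_block): for (size_t r = 0; r < 7; ++r) { for (size_t c = 0; c < 32; ++c) { float acc = (inc_flag & 1) ? dst[r * dst_stride + c] : (bias ? bias[c] : 0.0f); for (size_t d = 0; d < depth; ++d) { acc += src[r * src_stride + d] * weight[d * 32 + c]; } dst[r * dst_stride + c] = acc; } } On the final depth slice (inc_flag bit 1), act_flag selects ReLU and, when its low bit is set, the additional ReLU6 clamp; the constant 0x40C00000 below is the IEEE-754 encoding of 6.0f. */ +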
asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), 
%%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, 
%%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 
1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), 
%%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c new file mode 100644 index 
0000000000000000000000000000000000000000..cc4b94c3aed1b51876cc081a32644643e8fa9087 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c @@ -0,0 +1,765 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, 
%%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 4(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps 
%%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 8(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 12(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], 
%[src_stride], 2), %%zmm23\n" + "vbroadcastss 16(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 20(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 24(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, 
%%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 28(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 32(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 36(%[src_6]), %%zmm22\n" + "vfmadd231ps 
%%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 40(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 44(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, 
%%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 48(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 52(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 56(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps 
%%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 60(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), 
[src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fb6fb03711a732f9c3035a643922306091a8dd76 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c @@ -0,0 +1,448 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp 
$16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 
16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps 
%%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + 
"vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + 
"2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..93d37efc9dfec61a6cc461f8214546050ef74fb3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c @@ -0,0 +1,650 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 
0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 
2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), 
%%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), 
%%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], 
%[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], 
%[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + 
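// the remaining vminps below clamp accumulators zmm2-zmm15 against the broadcast 6.0f to complete relu6 + 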
"vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b04839be8581daea7c9cb2a0faefe47a71bd2162 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c @@ -0,0 +1,852 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm21\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm22\n" + "vmovups 128(%[dst_6], %[dst_stride], 1), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 0(%[bias]), %%zmm21\n" + "vmovups 64(%[bias]), %%zmm22\n" + "vmovups 128(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + 
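// zmm0-zmm23 form the 8x48 accumulator tile (8 rows x 3 vectors of 16 floats); they are zeroed here when neither dst (inc_flag bit 0) nor bias seeds them + 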
"vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_6]), %%zmm27\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_6]), %%zmm27\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_6]), %%zmm27\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), 
%%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_6]), %%zmm27\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_6]), %%zmm27\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_6]), %%zmm27\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_6]), %%zmm27\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, 
%%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_6]), %%zmm27\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_6]), %%zmm27\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_6]), %%zmm27\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm26\n" + 
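// block 10, second wave: broadcasts for rows 5-7 (src_3 + 2*stride, src_6, src_6 + stride) feed accumulators zmm15-zmm23 + 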
"vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_6]), %%zmm27\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_6]), %%zmm27\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm23\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_6]), %%zmm27\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_6]), %%zmm27\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + 
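// block 15 is the last of the 16 unrolled depth steps; afterwards weight advances 3072 bytes (16 steps x 48 floats) and each src row pointer 64 bytes + 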
"vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_6]), %%zmm27\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, 
%%zmm5\n"
+    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
+    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
+    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
+    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
+    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
+    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
+    "vmaxps %%zmm12, %%zmm31, %%zmm12\n"
+    "vmaxps %%zmm13, %%zmm31, %%zmm13\n"
+    "vmaxps %%zmm14, %%zmm31, %%zmm14\n"
+    "vmaxps %%zmm15, %%zmm31, %%zmm15\n"
+    "vmaxps %%zmm16, %%zmm31, %%zmm16\n"
+    "vmaxps %%zmm17, %%zmm31, %%zmm17\n"
+    "vmaxps %%zmm18, %%zmm31, %%zmm18\n"
+    "vmaxps %%zmm19, %%zmm31, %%zmm19\n"
+    "vmaxps %%zmm20, %%zmm31, %%zmm20\n"
+    "vmaxps %%zmm21, %%zmm31, %%zmm21\n"
+    "vmaxps %%zmm22, %%zmm31, %%zmm22\n"
+    "vmaxps %%zmm23, %%zmm31, %%zmm23\n"
+    "and $0x1, %[act_flag]\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1\n"
+    "vminps %%zmm2, %%zmm30, %%zmm2\n"
+    "vminps %%zmm3, %%zmm30, %%zmm3\n"
+    "vminps %%zmm4, %%zmm30, %%zmm4\n"
+    "vminps %%zmm5, %%zmm30, %%zmm5\n"
+    "vminps %%zmm6, %%zmm30, %%zmm6\n"
+    "vminps %%zmm7, %%zmm30, %%zmm7\n"
+    "vminps %%zmm8, %%zmm30, %%zmm8\n"
+    "vminps %%zmm9, %%zmm30, %%zmm9\n"
+    "vminps %%zmm10, %%zmm30, %%zmm10\n"
+    "vminps %%zmm11, %%zmm30, %%zmm11\n"
+    "vminps %%zmm12, %%zmm30, %%zmm12\n"
+    "vminps %%zmm13, %%zmm30, %%zmm13\n"
+    "vminps %%zmm14, %%zmm30, %%zmm14\n"
+    "vminps %%zmm15, %%zmm30, %%zmm15\n"
+    "vminps %%zmm16, %%zmm30, %%zmm16\n"
+    "vminps %%zmm17, %%zmm30, %%zmm17\n"
+    "vminps %%zmm18, %%zmm30, %%zmm18\n"
+    "vminps %%zmm19, %%zmm30, %%zmm19\n"
+    "vminps %%zmm20, %%zmm30, %%zmm20\n"
+    "vminps %%zmm21, %%zmm30, %%zmm21\n"
+    "vminps %%zmm22, %%zmm30, %%zmm22\n"
+    "vminps %%zmm23, %%zmm30, %%zmm23\n"
+    ".align 16\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0])\n"
+    "vmovups %%zmm1, 64(%[dst_0])\n"
+    "vmovups %%zmm2, 128(%[dst_0])\n"
+    "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm9, 0(%[dst_3])\n"
+    "vmovups %%zmm10, 64(%[dst_3])\n"
+    "vmovups %%zmm11, 128(%[dst_3])\n"
+    "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n"
+    "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n"
+    "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n"
+    "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n"
+    "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n"
+    "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n"
+    "vmovups %%zmm18, 0(%[dst_6])\n"
+    "vmovups %%zmm19, 64(%[dst_6])\n"
+    "vmovups %%zmm20, 128(%[dst_6])\n"
+    "vmovups %%zmm21, 0(%[dst_6], %[dst_stride], 1)\n"
+    "vmovups %%zmm22, 64(%[dst_6], %[dst_stride], 1)\n"
+    "vmovups %%zmm23, 128(%[dst_6], %[dst_stride], 1)\n"
+    :
+    : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t),
+      [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
+      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
+      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
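The kernel that just ended and the files that follow are all emitted by the same HPC generator, so they share one scheme: a tile of zmm accumulators (here 8 rows x 48 columns, zmm0-zmm23), weights and broadcast activations streamed through zmm24-zmm31, a 16-step unrolled depth loop with a scalar remainder, and a relu/relu6 epilogue. As a reading aid, here is a hedged plain-C sketch of the contract these kernels appear to implement; the function name and loop structure are illustrative, not part of the patch, and it assumes the strides are counted in floats (the asm scales them to bytes with the "<< 2" shifts).

#include <stddef.h>

/* Hedged reference sketch (not part of the patch): what one row_block x
 * col_block kernel call computes. Row r of src/dst starts at r * stride
 * floats, matching the asm's src_stride_t/dst_stride_t byte strides. */
static void gemm_tile_reference(float *dst, const float *src, const float *weight, const float *bias,
                                size_t act_flag, size_t row_block, size_t col_block, size_t depth,
                                size_t src_stride, size_t dst_stride, size_t inc_flag) {
  for (size_t r = 0; r < row_block; ++r) {
    for (size_t c = 0; c < col_block; ++c) {
      /* inc_flag bit 0: accumulate onto existing dst; otherwise init from bias (or zero). */
      float acc = (inc_flag & 0x1) ? dst[r * dst_stride + c] : (bias != NULL ? bias[c] : 0.0f);
      for (size_t d = 0; d < depth; ++d) {
        acc += src[r * src_stride + d] * weight[d * col_block + c]; /* the vfmadd231ps chain */
      }
      /* inc_flag bit 1 gates activation; any of act_flag bits 0-1 selects relu,
       * and act_flag bit 0 additionally applies the 6.0f clamp (relu6). */
      if ((inc_flag & 0x2) && (act_flag & 0x3)) {
        acc = acc > 0.0f ? acc : 0.0f;
        if (act_flag & 0x1) {
          acc = acc < 6.0f ? acc : 6.0f;
        }
      }
      dst[r * dst_stride + c] = acc;
    }
  }
}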
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..da7707ae84e341333fbe0c256169fcb0f93af27e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c
@@ -0,0 +1,488 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_9x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n"
+    "vmovups 0(%[dst_3]), %%zmm3\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n"
+    "vmovups 0(%[dst_6]), %%zmm6\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 0(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 0(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 0(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 0(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3),
+      [dst_6] "r"(dst_6)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8");
+  const float *src_3 = src + 3 * src_stride;
+  const float *src_6 = src + 6 * src_stride;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vbroadcastss 0(%[src_0]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n"
+    "vbroadcastss 0(%[src_0], 
%[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + 
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, 
%%zmm22, %%zmm8\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + 
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n"
+    "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n"
+    "add $1024, %[weight]\n"
+    "add $64, %[src_0]\n"
+    "add $64, %[src_3]\n"
+    "add $64, %[src_6]\n"
+    "sub $16, %[depth]\n"
+    "cmp $16, %[depth]\n"
+    "jge 0b\n"
+    "cmp $0, %[depth]\n"
+    "je 2f\n"
+    ".align 16\n"
+    "1:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vbroadcastss 0(%[src_0]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n"
+    "vbroadcastss 0(%[src_3]), %%zmm27\n"
+    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n"
+    "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n"
+    "vbroadcastss 0(%[src_6]), %%zmm24\n"
+    "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n"
+    "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n"
+    "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n"
+    "add $64, %[weight]\n"
+    "add $4, %[src_0]\n"
+    "add $4, %[src_3]\n"
+    "add $4, %[src_6]\n"
+    "dec %[depth]\n"
+    "jg 1b\n"
+    ".align 16\n"
+    "2:\n"
+    "and $0x2, %[inc_flag]\n"
+    "je 3f\n"
+    "and $0x3, %[act_flag]\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
+    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
+    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
+    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
+    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
+    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
+    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
+    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
+    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
+    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
+    "and $0x1, %[act_flag]\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1\n"
+    "vminps %%zmm2, %%zmm30, %%zmm2\n"
+    "vminps %%zmm3, %%zmm30, %%zmm3\n"
+    "vminps %%zmm4, %%zmm30, %%zmm4\n"
+    "vminps %%zmm5, %%zmm30, %%zmm5\n"
+    "vminps %%zmm6, %%zmm30, %%zmm6\n"
+    "vminps %%zmm7, %%zmm30, %%zmm7\n"
+    "vminps %%zmm8, %%zmm30, %%zmm8\n"
+    ".align 16\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0])\n"
+    "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm3, 0(%[dst_3])\n"
+    "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n"
+    "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n"
+    "vmovups %%zmm6, 0(%[dst_6])\n"
+    "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n"
+    "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n"
+    :
+    : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t),
+      [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
+      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
+      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
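One detail that recurs in every epilogue above: the relu6 bound is materialized as a raw immediate moved into eax and broadcast, rather than loaded from memory. That immediate, 0x40C00000, is simply the IEEE-754 single-precision bit pattern of 6.0f, as a standalone check (illustrative, not part of the patch) confirms:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  uint32_t bits = 0x40C00000; /* the "mov $0x40C00000, %eax" immediate */
  float six;
  memcpy(&six, &bits, sizeof(six)); /* type-pun without strict-aliasing UB */
  printf("%f\n", six);              /* prints 6.000000 */
  return 0;
}

The depth-loop bookkeeping also follows directly from the tile shape: this 9x16 kernel advances weight by $1024 per main-loop iteration (16 depth steps x 16 columns x 4 bytes) and by $64 per remainder iteration (one step x 16 columns x 4 bytes).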
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..95e0308e03ac2db2b30f90b775a391e9476ef1dd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c
@@ -0,0 +1,713 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_9x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
+    "vmovups 0(%[dst_3]), %%zmm6\n"
+    "vmovups 64(%[dst_3]), %%zmm7\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 0(%[dst_6]), %%zmm12\n"
+    "vmovups 64(%[dst_6]), %%zmm13\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 64(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "vmovups 64(%[bias]), %%zmm9\n"
+    "vmovups 0(%[bias]), %%zmm10\n"
+    "vmovups 64(%[bias]), %%zmm11\n"
+    "vmovups 0(%[bias]), %%zmm12\n"
+    "vmovups 64(%[bias]), %%zmm13\n"
+    "vmovups 0(%[bias]), %%zmm14\n"
+    "vmovups 64(%[bias]), %%zmm15\n"
+    "vmovups 0(%[bias]), %%zmm16\n"
+    "vmovups 64(%[bias]), %%zmm17\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+
"vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + 
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" 
+ "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" 
+ "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps 
%%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + 
"vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps 
%%zmm12, %%zmm31, %%zmm12\n"
+    "vmaxps %%zmm13, %%zmm31, %%zmm13\n"
+    "vmaxps %%zmm14, %%zmm31, %%zmm14\n"
+    "vmaxps %%zmm15, %%zmm31, %%zmm15\n"
+    "vmaxps %%zmm16, %%zmm31, %%zmm16\n"
+    "vmaxps %%zmm17, %%zmm31, %%zmm17\n"
+    "and $0x1, %[act_flag]\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1\n"
+    "vminps %%zmm2, %%zmm30, %%zmm2\n"
+    "vminps %%zmm3, %%zmm30, %%zmm3\n"
+    "vminps %%zmm4, %%zmm30, %%zmm4\n"
+    "vminps %%zmm5, %%zmm30, %%zmm5\n"
+    "vminps %%zmm6, %%zmm30, %%zmm6\n"
+    "vminps %%zmm7, %%zmm30, %%zmm7\n"
+    "vminps %%zmm8, %%zmm30, %%zmm8\n"
+    "vminps %%zmm9, %%zmm30, %%zmm9\n"
+    "vminps %%zmm10, %%zmm30, %%zmm10\n"
+    "vminps %%zmm11, %%zmm30, %%zmm11\n"
+    "vminps %%zmm12, %%zmm30, %%zmm12\n"
+    "vminps %%zmm13, %%zmm30, %%zmm13\n"
+    "vminps %%zmm14, %%zmm30, %%zmm14\n"
+    "vminps %%zmm15, %%zmm30, %%zmm15\n"
+    "vminps %%zmm16, %%zmm30, %%zmm16\n"
+    "vminps %%zmm17, %%zmm30, %%zmm17\n"
+    ".align 16\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0])\n"
+    "vmovups %%zmm1, 64(%[dst_0])\n"
+    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm6, 0(%[dst_3])\n"
+    "vmovups %%zmm7, 64(%[dst_3])\n"
+    "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n"
+    "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n"
+    "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n"
+    "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n"
+    "vmovups %%zmm12, 0(%[dst_6])\n"
+    "vmovups %%zmm13, 64(%[dst_6])\n"
+    "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n"
+    "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n"
+    "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n"
+    "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n"
+    :
+    : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t),
+      [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
+      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
+      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
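After three inline-asm AVX512 kernels, the patch switches to the FMA generator output below, which expresses the same tile pattern (10 rows x 8 columns, NC8HW8 packing) with _mm256 intrinsics instead of asm. One point worth double-checking against the generator: the file writes dst0 = _mm256_fmadd_ps(dst0, src00, weight00), and since _mm256_fmadd_ps(a, b, c) computes a * b + c, that operand order evaluates dst0 * src00 + weight00 rather than accumulating into dst0; the conventional accumulate passes the running sum as the third argument. A minimal sketch of that convention (illustrative only, not the patch's code):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
  __m256 src = _mm256_set1_ps(2.0f);    /* broadcast activation, like _mm256_set1_ps(*(src + 0)) */
  __m256 weight = _mm256_set1_ps(3.0f); /* one packed weight vector */
  __m256 acc = _mm256_set1_ps(1.0f);    /* running accumulator */
  acc = _mm256_fmadd_ps(src, weight, acc); /* a * b + c: 2 * 3 + 1 = 7 */
  float out[8];
  _mm256_storeu_ps(out, acc);
  printf("%f\n", out[0]); /* prints 7.000000 */
  return 0;
}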
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  __m256 dst0;
+  __m256 dst1;
+  __m256 dst2;
+  __m256 dst3;
+  __m256 dst4;
+  __m256 dst5;
+  __m256 dst6;
+  __m256 dst7;
+  __m256 dst8;
+  __m256 dst9;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
+    dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
+    dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
+    dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
+    dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
+    dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
+    dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
+    dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
+    dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+    dst2 = _mm256_setzero_ps();
+    dst3 = _mm256_setzero_ps();
+    dst4 = _mm256_setzero_ps();
+    dst5 = _mm256_setzero_ps();
+    dst6 = _mm256_setzero_ps();
+    dst7 = _mm256_setzero_ps();
+    dst8 = _mm256_setzero_ps();
+    dst9 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 0);
+    dst2 = _mm256_load_ps(bias + 0);
+    dst3 = _mm256_load_ps(bias + 0);
+    dst4 = _mm256_load_ps(bias + 0);
+    dst5 = _mm256_load_ps(bias + 0);
+    dst6 = _mm256_load_ps(bias + 0);
+    dst7 = _mm256_load_ps(bias + 0);
+    dst8 = _mm256_load_ps(bias + 0);
+    dst9 = _mm256_load_ps(bias + 0);
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    // block0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
+    __m256 src10 = _mm256_set1_ps(*(src + 8));
+    dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
+    __m256 src20 = _mm256_set1_ps(*(src + 16));
+    dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
+    __m256 src30 = _mm256_set1_ps(*(src + 24));
+    dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
+    __m256 src40 = _mm256_set1_ps(*(src + 32));
+    dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
+    __m256 src50 = _mm256_set1_ps(*(src + 40));
+    dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
+    __m256 src60 = _mm256_set1_ps(*(src + 48));
+    dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
+    __m256 src70 = _mm256_set1_ps(*(src + 56));
+    dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
+    __m256 src80 = _mm256_set1_ps(*(src + 64));
+    dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
+    __m256 src90 = _mm256_set1_ps(*(src + 72));
+    dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
+    // block1
+    __m256 weight01 = _mm256_load_ps(weight + 8);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
+    __m256 src11 = _mm256_set1_ps(*(src + 9));
+    dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
+    __m256 src21 = _mm256_set1_ps(*(src + 17));
+    dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
+    __m256 src31
= _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = 
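+    // Added reading note on this generated kernel: each "bock<k>" section is one step of the
+    // 8-way unrolled depth loop. Per step it loads one 8-float column vector of the packed
+    // weights (weight + 8 * k) and, for each of the 10 rows, broadcasts the scalar at
+    // src[row * 8 + k] before folding it into that row's __m256 accumulator with an FMA.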
_mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = 
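+    // Activation convention visible in this tail: act_flag & 0x02 requests relu6 (clamp to
+    // 6.0f here, then the shared floor at 0 below), act_flag & 0x01 requests plain relu;
+    // with neither bit set the accumulators are stored back unmodified.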
_mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..0470cc51508e7060f30b0d051700c252fe4be953 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,303 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 32(%[dst]), %%ymm1\n"
+    "vmovups 64(%[dst]), %%ymm2\n"
+    "vmovups 96(%[dst]), %%ymm3\n"
+    "vmovups 128(%[dst]), %%ymm4\n"
+    "vmovups 160(%[dst]), %%ymm5\n"
+    "vmovups 192(%[dst]), %%ymm6\n"
+    "vmovups 224(%[dst]), %%ymm7\n"
+    "vmovups 256(%[dst]), %%ymm8\n"
+    "vmovups 288(%[dst]), %%ymm9\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 0(%[bias]), %%ymm1\n"
+    "vmovaps 0(%[bias]), %%ymm2\n"
+    "vmovaps 0(%[bias]), %%ymm3\n"
+    "vmovaps 0(%[bias]), %%ymm4\n"
+    "vmovaps 0(%[bias]), %%ymm5\n"
+    "vmovaps 0(%[bias]), %%ymm6\n"
+    "vmovaps 0(%[bias]), %%ymm7\n"
+    "vmovaps 0(%[bias]), %%ymm8\n"
+    "vmovaps 0(%[bias]), %%ymm9\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
+    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
+    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
+    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
+    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
+    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
+    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
+    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9");
+  asm volatile(
+    "0:\n"
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vbroadcastss 0(%[src]), %%ymm14\n"
+    "vbroadcastss 32(%[src]), %%ymm13\n"
+    "vbroadcastss 64(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
+    "vbroadcastss 96(%[src]), %%ymm14\n"
+    "vbroadcastss 128(%[src]), %%ymm13\n"
+    "vbroadcastss 160(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+    "vbroadcastss 192(%[src]), %%ymm14\n"
+    "vbroadcastss 224(%[src]), %%ymm13\n"
+    "vbroadcastss 256(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
+    "vbroadcastss 288(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
+    // block 1
+    "vmovaps 32(%[weight]), %%ymm15\n"
+    "vbroadcastss 1(%[src]), %%ymm14\n"
+    "vbroadcastss 33(%[src]), %%ymm13\n"
+    "vbroadcastss 65(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
+    "vbroadcastss 97(%[src]), %%ymm14\n"
+    "vbroadcastss 129(%[src]), %%ymm13\n"
+    "vbroadcastss 161(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+    "vbroadcastss 193(%[src]), %%ymm14\n"
+    "vbroadcastss 225(%[src]), %%ymm13\n"
+    "vbroadcastss 257(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
+    "vbroadcastss 289(%[src]),
%%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, 
%%ymm15\n"
+    // block 6
+    "vmovaps 192(%[weight]), %%ymm15\n"
+    "vbroadcastss 6(%[src]), %%ymm14\n"
+    "vbroadcastss 38(%[src]), %%ymm13\n"
+    "vbroadcastss 70(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
+    "vbroadcastss 102(%[src]), %%ymm14\n"
+    "vbroadcastss 134(%[src]), %%ymm13\n"
+    "vbroadcastss 166(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+    "vbroadcastss 198(%[src]), %%ymm14\n"
+    "vbroadcastss 230(%[src]), %%ymm13\n"
+    "vbroadcastss 262(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
+    "vbroadcastss 294(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
+    // block 7
+    "vmovaps 224(%[weight]), %%ymm15\n"
+    "vbroadcastss 7(%[src]), %%ymm14\n"
+    "vbroadcastss 39(%[src]), %%ymm13\n"
+    "vbroadcastss 71(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
+    "vbroadcastss 103(%[src]), %%ymm14\n"
+    "vbroadcastss 135(%[src]), %%ymm13\n"
+    "vbroadcastss 167(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+    "vbroadcastss 199(%[src]), %%ymm14\n"
+    "vbroadcastss 231(%[src]), %%ymm13\n"
+    "vbroadcastss 263(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
+    "vbroadcastss 295(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "dec %[deep]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
+    "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
+    "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
+    "vmaxps %%ymm5, %%ymm15, %%ymm5\n"
+    "vmaxps %%ymm6, %%ymm15, %%ymm6\n"
+    "vmaxps %%ymm7, %%ymm15, %%ymm7\n"
+    "vmaxps %%ymm8, %%ymm15, %%ymm8\n"
+    "vmaxps %%ymm9, %%ymm15, %%ymm9\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "vminps %%ymm2, %%ymm14, %%ymm2\n"
+    "vminps %%ymm3, %%ymm14, %%ymm3\n"
+    "vminps %%ymm4, %%ymm14, %%ymm4\n"
+    "vminps %%ymm5, %%ymm14, %%ymm5\n"
+    "vminps %%ymm6, %%ymm14, %%ymm6\n"
+    "vminps %%ymm7, %%ymm14, %%ymm7\n"
+    "vminps %%ymm8, %%ymm14, %%ymm8\n"
+    "vminps %%ymm9, %%ymm14, %%ymm9\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 32(%[dst])\n"
+    "vmovups %%ymm2, 64(%[dst])\n"
+    "vmovups %%ymm3, 96(%[dst])\n"
+    "vmovups %%ymm4, 128(%[dst])\n"
+    "vmovups %%ymm5, 160(%[dst])\n"
+    "vmovups %%ymm6, 192(%[dst])\n"
+    "vmovups %%ymm7, 224(%[dst])\n"
+    "vmovups %%ymm8, 256(%[dst])\n"
+    "vmovups %%ymm9, 288(%[dst])\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..dcc963de7ef6fd9c1206b062c5300b8e43a24ccd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + __m256 dst10; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, 
weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + __m256 src100 = _mm256_set1_ps(*(src + 80)); + dst10 = _mm256_fmadd_ps(dst10, src100, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = 
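+    // Row broadcasts in this 11-row variant follow the same NC8HW8 addressing as the 10x8
+    // kernel above: depth step k of row r reads the scalar at src + r * 8 + k, i.e. the
+    // packed rows sit 8 floats apart.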
_mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = 
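+    // After the eighth depth step the loop advances src by src_stride floats (one packed
+    // block of rows) and weight by 256 floats, as written at the bottom of this loop body.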
_mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
new file mode 100644
index 0000000000000000000000000000000000000000..83a66a8a53527571e7cf36450af33dac02dd9d2b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,325 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 32(%[dst]), %%ymm1\n"
+    "vmovups 64(%[dst]), %%ymm2\n"
+    "vmovups 96(%[dst]), %%ymm3\n"
+    "vmovups 128(%[dst]), %%ymm4\n"
+    "vmovups 160(%[dst]), %%ymm5\n"
+    "vmovups 192(%[dst]), %%ymm6\n"
+    "vmovups 224(%[dst]), %%ymm7\n"
+    "vmovups 256(%[dst]), %%ymm8\n"
+    "vmovups 288(%[dst]), %%ymm9\n"
+    "vmovups 320(%[dst]), %%ymm10\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 0(%[bias]), %%ymm1\n"
+    "vmovaps 0(%[bias]), %%ymm2\n"
+    "vmovaps 0(%[bias]), %%ymm3\n"
+    "vmovaps 0(%[bias]), %%ymm4\n"
+    "vmovaps 0(%[bias]), %%ymm5\n"
+    "vmovaps 0(%[bias]), %%ymm6\n"
+    "vmovaps 0(%[bias]), %%ymm7\n"
+    "vmovaps 0(%[bias]), %%ymm8\n"
+    "vmovaps 0(%[bias]), %%ymm9\n"
+    "vmovaps 0(%[bias]), %%ymm10\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
+    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
+    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
+    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
+    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
+    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
+    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
+    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
+    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10");
+  asm volatile(
+    "0:\n"
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vbroadcastss 0(%[src]), %%ymm14\n"
+    "vbroadcastss 32(%[src]), %%ymm13\n"
+    "vbroadcastss 64(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
+    "vbroadcastss 96(%[src]), %%ymm14\n"
+    "vbroadcastss 128(%[src]), %%ymm13\n"
+    "vbroadcastss 160(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+    "vbroadcastss
192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vbroadcastss 320(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vbroadcastss 321(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vbroadcastss 322(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vbroadcastss 323(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, 
%%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vbroadcastss 324(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vbroadcastss 325(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vbroadcastss 326(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vbroadcastss 327(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, 
%%ymm13, %%ymm15\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "dec %[deep]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
+    "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
+    "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
+    "vmaxps %%ymm5, %%ymm15, %%ymm5\n"
+    "vmaxps %%ymm6, %%ymm15, %%ymm6\n"
+    "vmaxps %%ymm7, %%ymm15, %%ymm7\n"
+    "vmaxps %%ymm8, %%ymm15, %%ymm8\n"
+    "vmaxps %%ymm9, %%ymm15, %%ymm9\n"
+    "vmaxps %%ymm10, %%ymm15, %%ymm10\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "vminps %%ymm2, %%ymm14, %%ymm2\n"
+    "vminps %%ymm3, %%ymm14, %%ymm3\n"
+    "vminps %%ymm4, %%ymm14, %%ymm4\n"
+    "vminps %%ymm5, %%ymm14, %%ymm5\n"
+    "vminps %%ymm6, %%ymm14, %%ymm6\n"
+    "vminps %%ymm7, %%ymm14, %%ymm7\n"
+    "vminps %%ymm8, %%ymm14, %%ymm8\n"
+    "vminps %%ymm9, %%ymm14, %%ymm9\n"
+    "vminps %%ymm10, %%ymm14, %%ymm10\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 32(%[dst])\n"
+    "vmovups %%ymm2, 64(%[dst])\n"
+    "vmovups %%ymm3, 96(%[dst])\n"
+    "vmovups %%ymm4, 128(%[dst])\n"
+    "vmovups %%ymm5, 160(%[dst])\n"
+    "vmovups %%ymm6, 192(%[dst])\n"
+    "vmovups %%ymm7, 224(%[dst])\n"
+    "vmovups %%ymm8, 256(%[dst])\n"
+    "vmovups %%ymm9, 288(%[dst])\n"
+    "vmovups %%ymm10, 320(%[dst])\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..8e39ab64baf7591f29c7dae4e14fd0d721a9b27f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
@@ -0,0 +1,345 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  __m256 dst0;
+  __m256 dst1;
+  __m256 dst2;
+  __m256 dst3;
+  __m256 dst4;
+  __m256 dst5;
+  __m256 dst6;
+  __m256 dst7;
+  __m256 dst8;
+  __m256 dst9;
+  __m256 dst10;
+  __m256 dst11;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
+    dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
+    dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
+    dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
+    dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
+    dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
+    dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
+    dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
+    dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
+    dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80);
+    dst11 = _mm256_load_ps(dst + 0 * dst_stride + 88);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+    dst2 = _mm256_setzero_ps();
+    dst3 = _mm256_setzero_ps();
+    dst4 = _mm256_setzero_ps();
+    dst5 = _mm256_setzero_ps();
+    dst6 = _mm256_setzero_ps();
+    dst7 = _mm256_setzero_ps();
+    dst8 = _mm256_setzero_ps();
+    dst9 = _mm256_setzero_ps();
+    dst10 = _mm256_setzero_ps();
+    dst11 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 0);
+    dst2 = _mm256_load_ps(bias + 0);
+    dst3 = _mm256_load_ps(bias + 0);
+    dst4 = _mm256_load_ps(bias + 0);
+    dst5 = _mm256_load_ps(bias + 0);
+    dst6 = _mm256_load_ps(bias + 0);
+    dst7 = _mm256_load_ps(bias + 0);
+    dst8 = _mm256_load_ps(bias + 0);
+    dst9 = _mm256_load_ps(bias + 0);
+    dst10 = _mm256_load_ps(bias + 0);
+    dst11 = _mm256_load_ps(bias + 0);
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    // block0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
+    __m256 src10 = _mm256_set1_ps(*(src + 8));
+    dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
+    __m256 src20 = _mm256_set1_ps(*(src + 16));
+    dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
+    __m256 src30 = _mm256_set1_ps(*(src + 24));
+    dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
+    __m256 src40 = _mm256_set1_ps(*(src + 32));
+    dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
+    __m256 src50 = _mm256_set1_ps(*(src + 40));
+    dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
+    __m256 src60 = _mm256_set1_ps(*(src + 48));
+    dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
+    __m256 src70 = _mm256_set1_ps(*(src + 56));
+    dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
+    __m256 src80 = _mm256_set1_ps(*(src + 64));
+    dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
+    __m256 src90 = _mm256_set1_ps(*(src + 72));
+    dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
+    __m256 src100 = _mm256_set1_ps(*(src + 80));
+    dst10 = _mm256_fmadd_ps(dst10, src100, weight00);
+    __m256 src110 = _mm256_set1_ps(*(src + 88));
+    dst11 = _mm256_fmadd_ps(dst11, src110, weight00);
+    // block1
+    __m256 weight01 = _mm256_load_ps(weight + 8);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
+    __m256 src11 = _mm256_set1_ps(*(src + 9));
+    dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
+    __m256 src21 = _mm256_set1_ps(*(src + 17));
+    dst2 =
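+    // Accumulator setup convention (shared by these generated kernels): inc_flag != 0 reloads
+    // the 12 partial-sum vectors from dst, a NULL bias starts them at zero, and otherwise each
+    // row accumulator is seeded with the same 8 bias floats for this column block.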
_mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + __m256 src111 = _mm256_set1_ps(*(src + 89)); + dst11 = _mm256_fmadd_ps(dst11, src111, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + __m256 src112 = _mm256_set1_ps(*(src + 90)); + dst11 = _mm256_fmadd_ps(dst11, src112, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + __m256 src113 = _mm256_set1_ps(*(src + 91)); + dst11 = _mm256_fmadd_ps(dst11, src113, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, 
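+    // _mm256_set1_ps(*(src + n)) is the intrinsic counterpart of the asm kernels'
+    // vbroadcastss: load one scalar and splat it across all eight float lanes.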
weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + __m256 src114 = _mm256_set1_ps(*(src + 92)); + dst11 = _mm256_fmadd_ps(dst11, src114, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + __m256 src115 = _mm256_set1_ps(*(src + 93)); + dst11 = _mm256_fmadd_ps(dst11, src115, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + __m256 src116 = _mm256_set1_ps(*(src + 94)); + dst11 = _mm256_fmadd_ps(dst11, src116, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = 
_mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + __m256 src117 = _mm256_set1_ps(*(src + 95)); + dst11 = _mm256_fmadd_ps(dst11, src117, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); + _mm256_store_ps(dst + 0 * src_stride + 88, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 
0000000000000000000000000000000000000000..e478150a22317e4ab2512693397f69887bbc3ab5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,347 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "vmovups 320(%[dst]), %%ymm10\n" + "vmovups 352(%[dst]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "vmovaps 0(%[bias]), %%ymm10\n" + "vmovaps 0(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 
224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vbroadcastss 320(%[src]), %%ymm13\n" + "vbroadcastss 352(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vbroadcastss 321(%[src]), %%ymm13\n" + "vbroadcastss 353(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vbroadcastss 322(%[src]), %%ymm13\n" + "vbroadcastss 354(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vbroadcastss 323(%[src]), %%ymm13\n" + "vbroadcastss 355(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, 
%%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vbroadcastss 324(%[src]), %%ymm13\n" + "vbroadcastss 356(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vbroadcastss 325(%[src]), %%ymm13\n" + "vbroadcastss 357(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vbroadcastss 326(%[src]), %%ymm13\n" + "vbroadcastss 358(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 
167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vbroadcastss 327(%[src]), %%ymm13\n" + "vbroadcastss 359(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + "vmovups %%ymm10, 320(%[dst])\n" + "vmovups %%ymm11, 352(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..87a830fbc57095e1e6d09a66107a7dd37c994cc1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
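In the intrinsic kernels of this family, the epilogue follows one convention: bit 1 of act_flag selects relu6 (clamp to [0, 6]) and bit 0 selects plain relu (clamp at 0). The asm twins, as in the 12x8 kernel above, additionally gate the whole activation step on bit 1 of inc_flag (the intrinsic versions apply it unconditionally), so partial depth slices are stored un-clamped. A minimal scalar sketch of that epilogue, illustrative only and not part of the diff:

#include <stddef.h>

/* Scalar restatement of the activation epilogue: bit 1 of act_flag is
 * relu6 (min with 6, then max with 0), bit 0 is plain relu. */
static inline float apply_act(float v, size_t act_flag) {
  if (act_flag & 0x02) {          /* relu6 */
    v = v > 6.0f ? 6.0f : v;
    v = v < 0.0f ? 0.0f : v;
  }
  if (act_flag & 0x01) {          /* relu */
    v = v < 0.0f ? 0.0f : v;
  }
  return v;
}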
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = 
_mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..5b592d4b6fac1698071c56a344b0d10e08f594b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps 
%%ymm1, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..548f143c0a08bee2d2940411152d5866808f5fa3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
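The relu6 path in the asm kernels, as in the 1x16 kernel just above, builds the constant 6.0f without a memory operand: 0x40C00000 is the IEEE-754 single-precision bit pattern of 6.0f, vmovd moves it into lane 0 of xmm14, and vpermps with ymm15 as the index register (zeroed moments earlier by the relu step) selects lane 0 into all eight lanes, i.e. a register-only broadcast. A standalone check of the bit pattern, assuming nothing from the diff:

#include <stdint.h>
#include <string.h>
#include <stdio.h>

int main(void) {
  uint32_t bits = 0x40C00000u;  /* sign 0, exponent 129, mantissa 1.5 => 1.5 * 2^2 */
  float six;
  memcpy(&six, &bits, sizeof six);
  printf("%f\n", six);          /* prints 6.000000 */
  return 0;
}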
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + dst2 = _mm256_fmadd_ps(dst2, src05, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + dst2 = _mm256_fmadd_ps(dst2, src06, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + 
dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..ca1be9fcfd7bdf86e36c822be297a68dc7777cc4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
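Each unrolled pass of these kernels consumes eight deep elements: the steps commented "bock0" through "bock7" are blocks 0-7. In the nc8hw8 packing they assume, the weights for one deep element are col_block consecutive floats split into 8-channel vectors, so the 1x24 kernel above reads three 8-float weight vectors per deep element and broadcasts one scalar from the source row. A plain-C restatement of one pass under that layout assumption (the helper name and the acc += src * weight spelling are illustrative):

/* One 8-deep pass of a 1x24 nc8hw8 micro-kernel, plain C. */
static void gemm_1x24_pass(float acc[3][8], const float *src, const float *weight) {
  for (int d = 0; d < 8; ++d) {      /* blocks 0..7 */
    float s = src[d];                /* _mm256_set1_ps(*(src + d)) */
    for (int v = 0; v < 3; ++v) {    /* three output-channel vectors */
      for (int c = 0; c < 8; ++c) {
        acc[v][c] += s * weight[d * 24 + v * 8 + c];
      }
    }
  }
}

The weight indices match the offsets the intrinsic version loads: block d reads weight + 24d, weight + 24d + 8, and weight + 24d + 16.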
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq 
%[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c23a6ff6fa4dc13a398b51abdd34446df6a2214b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
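The asm kernels split into two asm volatile blocks: the first initializes the accumulators three ways (inc_flag bit 0 set: reload the partial sums already stored in dst, because a previous depth slice ran first; otherwise load bias if non-NULL, else zero the registers), and the second runs the deep loop, activation, and stores. deep is pre-shifted to deep_t = deep >> 3 because each loop pass covers eight deep elements. A simplified C skeleton of that control flow, with strides collapsed to single elements for brevity:

#include <stddef.h>

static void kernel_skeleton(float *dst, const float *src, const float *weight,
                            const float *bias, size_t deep, size_t inc_flag) {
  float acc;                       /* stands in for ymm0..ymmN */
  if (inc_flag & 0x1) {
    acc = dst[0];                  /* resume partial sums across depth slices */
  } else if (bias != NULL) {
    acc = bias[0];                 /* vmovaps from bias */
  } else {
    acc = 0.0f;                    /* vxorps */
  }
  for (size_t i = 0; i < (deep >> 3); ++i) {  /* 8 deep elements per pass */
    for (int d = 0; d < 8; ++d) {
      acc += src[d] * weight[d];
    }
    src += 8;                      /* "add %[src_stride], %[src]" */
    weight += 8;                   /* per-pass weight advance */
  }
  dst[0] = acc;                    /* activation elided; see the epilogue sketch above */
}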
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 3 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst3 = _mm256_fmadd_ps(dst3, src00, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst3 = _mm256_fmadd_ps(dst3, src01, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst3 = _mm256_fmadd_ps(dst3, src02, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst3 = _mm256_fmadd_ps(dst3, src03, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst3 = _mm256_fmadd_ps(dst3, src04, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst2 = _mm256_fmadd_ps(dst2, 
src05, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst3 = _mm256_fmadd_ps(dst3, src05, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst2 = _mm256_fmadd_ps(dst2, src06, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst3 = _mm256_fmadd_ps(dst3, src06, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst3 = _mm256_fmadd_ps(dst3, src07, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); + _mm256_store_ps(dst + 3 * src_stride + 0, dst3); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..e327580af62c26cf80be5b8cde1afd8df0579120 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,173 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
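Because the asm addresses memory directly, the element strides are converted to byte strides up front (dst_stride_t = dst_stride << 2 and src_stride_t = src_stride << 2, for 4-byte floats). Rows 1 and 2 are then reached with the scaled-index forms 0(%[dst], %[dst_stride], 1) and 0(%[dst], %[dst_stride], 2); x86 SIB encoding only allows scales of 1, 2, 4, or 8, so the 1x32 asm variant that follows precomputes row 3 as a separate base pointer (dst_4 = dst + 3 * dst_stride). The equivalent pointer arithmetic in C, with strides in elements rather than bytes:

#include <stddef.h>

/* Row base pointers as the asm computes them; scale 3 is not encodable
 * in a single x86 addressing mode, hence the explicit fourth pointer. */
static void row_bases(const float *dst, size_t dst_stride, const float *rows[4]) {
  rows[0] = dst;                    /* 0(%[dst])                   */
  rows[1] = dst + 1 * dst_stride;   /* 0(%[dst], %[dst_stride], 1) */
  rows[2] = dst + 2 * dst_stride;   /* 0(%[dst], %[dst_stride], 2) */
  rows[3] = dst + 3 * dst_stride;   /* precomputed %[dst_4] base   */
}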
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "vmovups 0(%[dst_4]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "vmovaps 96(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_4] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vmovaps 0(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 64(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 192(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vmovaps 256(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vmovaps 384(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 448(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 544(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 576(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vmovaps 640(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 672(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 736(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + 
"vbroadcastss 6(%[src]), %%ymm15\n" + "vmovaps 768(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 800(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 832(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 864(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vmovaps 896(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 928(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 960(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 992(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm3, 0(%[dst_4])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_4] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a795d14be0e68391eb1af2a952028f56d88603b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..97a62290fecc491ed6bbe46db870ee6b8a704327 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
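For readability, a plain-C reference of what the 1x8 kernel just above computes: one output row, eight output channels, depth reduced eight elements per pass. Note that _mm256_fmadd_ps(a, b, c) evaluates a * b + c, so the accumulation spelled acc += src * weight below corresponds to _mm256_fmadd_ps(src, weight, acc). This is a sketch under the packing assumptions described earlier; the vector code steps weight by a fixed 256 per pass, while this reference advances by the 64 floats the pass actually reads:

#include <stddef.h>

static void gemm_1x8_ref(float *dst, const float *src, const float *weight,
                         const float *bias, size_t deep, size_t src_stride) {
  float acc[8];
  for (int c = 0; c < 8; ++c) acc[c] = (bias != NULL) ? bias[c] : 0.0f;
  for (size_t i = 0; i < (deep >> 3); ++i) {
    for (int d = 0; d < 8; ++d) {                 /* blocks 0..7 */
      for (int c = 0; c < 8; ++c) {
        acc[c] += src[d] * weight[d * 8 + c];
      }
    }
    src += src_stride;
    weight += 64;   /* floats referenced per pass in this reference layout */
  }
  for (int c = 0; c < 8; ++c) dst[c] = acc[c];    /* activation elided */
}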
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0");
+  asm volatile(
+    "0:\n"
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vbroadcastss 0(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 1
+    "vmovaps 32(%[weight]), %%ymm15\n"
+    "vbroadcastss 4(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 2
+    "vmovaps 64(%[weight]), %%ymm15\n"
+    "vbroadcastss 8(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 3
+    "vmovaps 96(%[weight]), %%ymm15\n"
+    "vbroadcastss 12(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 4
+    "vmovaps 128(%[weight]), %%ymm15\n"
+    "vbroadcastss 16(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 5
+    "vmovaps 160(%[weight]), %%ymm15\n"
+    "vbroadcastss 20(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 6
+    "vmovaps 192(%[weight]), %%ymm15\n"
+    "vbroadcastss 24(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    // block 7
+    "vmovaps 224(%[weight]), %%ymm15\n"
+    "vbroadcastss 28(%[src]), %%ymm14\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "dec %[deep]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..55787a6905a7b9bc3bf09178c6ef15183033fd6e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
@@ -0,0 +1,145 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
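The asm prologues scale both strides with `<< 2` because the assembly addresses memory in bytes, while the intrinsic twins work in float elements and let C pointer arithmetic do the scaling. A quick illustrative check of the equivalence (names are ours, not from the patch):

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

int main(void) {
  float buf[64];
  size_t stride = 16;                /* stride in elements */
  size_t stride_bytes = stride << 2; /* what the asm kernels precompute */
  assert((uintptr_t)(buf + stride) == (uintptr_t)buf + stride_bytes);
  return 0;
}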
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst1; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = 
_mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..d1f7cdce515900018a002d9a7101d08f8a9f20ed --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
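The four accumulators in the 2x16 kernel above tile two output pixels by two 8-channel output blocks: dst0/dst1 hold pixels 0 and 1 of the first channel block, dst2/dst3 the same pixels of the second block. Reading the loads (which index dst with dst_stride; the final stores index with src_stride, presumably intended to be the same quantity), each accumulator's address is captured by this small helper, our naming and our interpretation of the NC8HW8 packing:

#include <stddef.h>

/* Accumulator address for output pixel p (0 or 1) and 8-channel block c
 * (0 or 1) in the 2x16 tile: channel blocks lie dst_stride floats apart,
 * adjacent pixels within a block lie 8 floats apart. */
static inline float *tile2x16_addr(float *dst, size_t dst_stride, int p, int c) {
  return dst + (size_t)c * dst_stride + (size_t)p * 8;
}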
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps 
%%ymm3, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1de0370cb7c38f3ac6242fe18a7c4f5b9563057b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
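The relu6 epilogue in the asm kernels materializes the upper bound by moving the literal 0x40C00000 into %%eax, transferring it to lane 0 of an xmm register with vmovd, and broadcasting it with vpermps through an index register that was just zeroed, so every lane selects lane 0. The literal is simply the IEEE-754 single-precision encoding of 6.0f, which can be verified in C:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  uint32_t bits = 0x40C00000u; /* sign 0, biased exponent 129, mantissa 0x400000 */
  float six;
  memcpy(&six, &bits, sizeof(six));
  printf("%f\n", six); /* prints 6.000000 */
  return 0;
}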
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst4; + __m256 dst1; + __m256 dst3; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + dst4 = _mm256_fmadd_ps(dst4, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + dst5 = _mm256_fmadd_ps(dst5, src10, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + dst4 = _mm256_fmadd_ps(dst4, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + dst5 = _mm256_fmadd_ps(dst5, src11, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + dst4 = _mm256_fmadd_ps(dst4, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + dst5 = _mm256_fmadd_ps(dst5, src12, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + dst4 = _mm256_fmadd_ps(dst4, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + dst5 = _mm256_fmadd_ps(dst5, src13, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = 
_mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + dst4 = _mm256_fmadd_ps(dst4, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + dst5 = _mm256_fmadd_ps(dst5, src14, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + dst4 = _mm256_fmadd_ps(dst4, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + dst5 = _mm256_fmadd_ps(dst5, src15, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + dst4 = _mm256_fmadd_ps(dst4, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + dst5 = _mm256_fmadd_ps(dst5, src16, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + dst4 = _mm256_fmadd_ps(dst4, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + dst5 = _mm256_fmadd_ps(dst5, src17, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); + _mm256_store_ps(dst + 2 * src_stride + 0, dst4); + _mm256_store_ps(dst + 2 * src_stride + 8, dst5); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 
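Both kernel variants iterate `deep >> 3` times over a body unrolled for 8 depth steps, so any remainder of `deep` modulo 8 is silently dropped. Callers therefore have to pad the packed depth dimension up to a multiple of 8; assuming the packed buffers are zero-filled over the tail (our reading of the convention, not stated in the patch), the extra steps contribute nothing. A caller-side helper of our own devising:

#include <stddef.h>

/* Round a depth up to the next multiple of 8 for the unrolled kernels. */
static size_t pad_deep_to_8(size_t deep) { return (deep + 7) & ~(size_t)7; }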
index 0000000000000000000000000000000000000000..2c5e9306d2c9ff93aa987955d6cbd66381d823af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, 
%%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], 
%[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0f9dbfa6ed9f4000171d9e7bde56654be4a49f9c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst4; + __m256 dst6; + __m256 dst1; + __m256 dst3; + __m256 dst5; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 3 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 3 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 16); + dst6 = _mm256_load_ps(bias + 24); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 16); + dst7 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst4 = 
_mm256_fmadd_ps(dst4, src00, weight20); + dst5 = _mm256_fmadd_ps(dst5, src10, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst6 = _mm256_fmadd_ps(dst6, src00, weight30); + dst7 = _mm256_fmadd_ps(dst7, src10, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst4 = _mm256_fmadd_ps(dst4, src01, weight21); + dst5 = _mm256_fmadd_ps(dst5, src11, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst6 = _mm256_fmadd_ps(dst6, src01, weight31); + dst7 = _mm256_fmadd_ps(dst7, src11, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst4 = _mm256_fmadd_ps(dst4, src02, weight22); + dst5 = _mm256_fmadd_ps(dst5, src12, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst6 = _mm256_fmadd_ps(dst6, src02, weight32); + dst7 = _mm256_fmadd_ps(dst7, src12, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst4 = _mm256_fmadd_ps(dst4, src03, weight23); + dst5 = _mm256_fmadd_ps(dst5, src13, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst6 = _mm256_fmadd_ps(dst6, src03, weight33); + dst7 = _mm256_fmadd_ps(dst7, src13, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst4 = _mm256_fmadd_ps(dst4, src04, weight24); + dst5 = _mm256_fmadd_ps(dst5, src14, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst6 = _mm256_fmadd_ps(dst6, src04, weight34); + dst7 = _mm256_fmadd_ps(dst7, src14, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst4 = _mm256_fmadd_ps(dst4, src05, weight25); + dst5 = _mm256_fmadd_ps(dst5, src15, weight25); + 
__m256 weight35 = _mm256_load_ps(weight + 184); + dst6 = _mm256_fmadd_ps(dst6, src05, weight35); + dst7 = _mm256_fmadd_ps(dst7, src15, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst4 = _mm256_fmadd_ps(dst4, src06, weight26); + dst5 = _mm256_fmadd_ps(dst5, src16, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst6 = _mm256_fmadd_ps(dst6, src06, weight36); + dst7 = _mm256_fmadd_ps(dst7, src16, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst4 = _mm256_fmadd_ps(dst4, src07, weight27); + dst5 = _mm256_fmadd_ps(dst5, src17, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst6 = _mm256_fmadd_ps(dst6, src07, weight37); + dst7 = _mm256_fmadd_ps(dst7, src17, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); + _mm256_store_ps(dst + 2 * src_stride + 0, dst4); + _mm256_store_ps(dst + 2 * src_stride + 8, dst5); + _mm256_store_ps(dst + 3 * src_stride + 0, dst6); + _mm256_store_ps(dst + 3 * src_stride + 8, dst7); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..142ff10663f6d97f9f0fb2be69de065a328195b6 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "vmovups 0(%[dst_4]), %%ymm6\n" + "vmovups 32(%[dst_4]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "vmovaps 96(%[bias]), %%ymm6\n" + "vmovaps 96(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_4] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vmovaps 0(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 32(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 96(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vmovaps 128(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 192(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps 
%%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 224(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 288(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 320(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vmovaps 384(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 416(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 480(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vmovaps 512(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 576(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 608(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 672(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 704(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vmovaps 768(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 800(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 832(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 864(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vmovaps 896(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 928(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 960(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 
992(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm6, 0(%[dst_4])\n" + "vmovups %%ymm7, 32(%[dst_4])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_4] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6fcf7958a743436d8f0f4fe870a33acc148b94f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
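The four-block asm variants, such as the 2x32 kernel above, precompute `dst_4 = dst + 3 * dst_stride` in C rather than addressing it inline. The reason is an encoding limit: x86 SIB addressing can scale an index register only by 1, 2, 4 or 8, so the first three blocks are reachable as `0(%[dst])`, `0(%[dst], %[dst_stride], 1)` and `0(%[dst], %[dst_stride], 2)`, but the fourth block (hence the name dst_4) needs its own base register. The same idiom in C terms (an illustrative sketch):

#include <stddef.h>

/* Bases reachable through SIB scaling vs. the precomputed fourth base. */
static void row_bases_sketch(const float *dst, size_t dst_stride) {
  const float *row1 = dst + 1 * dst_stride;  /* 0(%[dst], %[dst_stride], 1) */
  const float *row2 = dst + 2 * dst_stride;  /* 0(%[dst], %[dst_stride], 2) */
  const float *dst_4 = dst + 3 * dst_stride; /* no scale-by-3 encoding */
  (void)row1; (void)row2; (void)dst_4;
}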
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  __m256 dst0;
+  __m256 dst1;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 0);
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    // block 0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(src00, weight00, dst0);
+    __m256 src10 = _mm256_set1_ps(*(src + 8));
+    dst1 = _mm256_fmadd_ps(src10, weight00, dst1);
+    // block 1
+    __m256 weight01 = _mm256_load_ps(weight + 8);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(src01, weight01, dst0);
+    __m256 src11 = _mm256_set1_ps(*(src + 9));
+    dst1 = _mm256_fmadd_ps(src11, weight01, dst1);
+    // block 2
+    __m256 weight02 = _mm256_load_ps(weight + 16);
+    __m256 src02 = _mm256_set1_ps(*(src + 2));
+    dst0 = _mm256_fmadd_ps(src02, weight02, dst0);
+    __m256 src12 = _mm256_set1_ps(*(src + 10));
+    dst1 = _mm256_fmadd_ps(src12, weight02, dst1);
+    // block 3
+    __m256 weight03 = _mm256_load_ps(weight + 24);
+    __m256 src03 = _mm256_set1_ps(*(src + 3));
+    dst0 = _mm256_fmadd_ps(src03, weight03, dst0);
+    __m256 src13 = _mm256_set1_ps(*(src + 11));
+    dst1 = _mm256_fmadd_ps(src13, weight03, dst1);
+    // block 4
+    __m256 weight04 = _mm256_load_ps(weight + 32);
+    __m256 src04 = _mm256_set1_ps(*(src + 4));
+    dst0 = _mm256_fmadd_ps(src04, weight04, dst0);
+    __m256 src14 = _mm256_set1_ps(*(src + 12));
+    dst1 = _mm256_fmadd_ps(src14, weight04, dst1);
+    // block 5
+    __m256 weight05 = _mm256_load_ps(weight + 40);
+    __m256 src05 = _mm256_set1_ps(*(src + 5));
+    dst0 = _mm256_fmadd_ps(src05, weight05, dst0);
+    __m256 src15 = _mm256_set1_ps(*(src + 13));
+    dst1 = _mm256_fmadd_ps(src15, weight05, dst1);
+    // block 6
+    __m256 weight06 = _mm256_load_ps(weight + 48);
+    __m256 src06 = _mm256_set1_ps(*(src + 6));
+    dst0 = _mm256_fmadd_ps(src06, weight06, dst0);
+    __m256 src16 = _mm256_set1_ps(*(src + 14));
+    dst1 = _mm256_fmadd_ps(src16, weight06, dst1);
+    // block 7
+    __m256 weight07 = _mm256_load_ps(weight + 56);
+    __m256 src07 = _mm256_set1_ps(*(src + 7));
+    dst0 = _mm256_fmadd_ps(src07, weight07, dst0);
+    __m256 src17 = _mm256_set1_ps(*(src + 15));
+    dst1 = _mm256_fmadd_ps(src17, weight07, dst1);
+    src = src + src_stride;
+    weight += 64;
+  }
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_min_ps(dst0, relu6);
+    dst1 = _mm256_min_ps(dst1, relu6);
+    // relu
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+  }
+  _mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
+  _mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
new file mode 100644
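The intrinsic kernels all end with the same activation selection: bit 1 of act_flag requests relu6 (clamp to [0, 6]) and bit 0 requests plain relu. A scalar equivalent of that tail for reference (our sketch, not patch code):

#include <stddef.h>

/* Activation handling as the intrinsic kernels implement it. */
static float apply_act(float v, size_t act_flag) {
  if (act_flag & 0x02) { /* relu6 */
    v = v > 6.0f ? 6.0f : v;
    v = v < 0.0f ? 0.0f : v;
  }
  if (act_flag & 0x01) { /* relu */
    v = v < 0.0f ? 0.0f : v;
  }
  return v;
}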
index 0000000000000000000000000000000000000000..d072491715ce745a88f4bc19cf60798b98f56f57
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 32(%[dst]), %%ymm1\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 0(%[bias]), %%ymm1\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0", "%ymm1");
+  asm volatile(
+    "0:\n"
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vbroadcastss 0(%[src]), %%ymm14\n"
+    "vbroadcastss 32(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 1
+    "vmovaps 32(%[weight]), %%ymm15\n"
+    "vbroadcastss 4(%[src]), %%ymm14\n"
+    "vbroadcastss 36(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 2
+    "vmovaps 64(%[weight]), %%ymm15\n"
+    "vbroadcastss 8(%[src]), %%ymm14\n"
+    "vbroadcastss 40(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 3
+    "vmovaps 96(%[weight]), %%ymm15\n"
+    "vbroadcastss 12(%[src]), %%ymm14\n"
+    "vbroadcastss 44(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 4
+    "vmovaps 128(%[weight]), %%ymm15\n"
+    "vbroadcastss 16(%[src]), %%ymm14\n"
+    "vbroadcastss 48(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 5
+    "vmovaps 160(%[weight]), %%ymm15\n"
+    "vbroadcastss 20(%[src]), %%ymm14\n"
+    "vbroadcastss 52(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 6
+    "vmovaps 192(%[weight]), %%ymm15\n"
+    "vbroadcastss 24(%[src]), %%ymm14\n"
+    "vbroadcastss 56(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    // block 7
+    "vmovaps 224(%[weight]), %%ymm15\n"
+    "vbroadcastss 28(%[src]), %%ymm14\n"
+    "vbroadcastss 60(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "dec %[deep]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 32(%[dst])\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..8cec925905bd13f881d6363751c3fba0944f3a78
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
@@ -0,0 +1,185 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst1; + __m256 dst4; + __m256 dst2; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = 
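// NOTE: _mm256_fmadd_ps(a, b, c) computes a * b + c, so an accumulating GEMM
// update is presumably meant to read dst = _mm256_fmadd_ps(src, weight, dst);
// as transcribed here, the calls compute dst * src + weight instead.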
_mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 
0000000000000000000000000000000000000000..b03ef6e2bd326140778d2753960c7efd1b738d4c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + 
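// Register tiling of this 3x16 kernel: %%ymm0-%%ymm2 hold the three output rows
// of the first 8-channel column block and %%ymm3-%%ymm5 the second; every depth
// step loads two 32-byte weight vectors (%%ymm15/%%ymm14) and broadcasts one
// float per row from %[src], where rows sit 32 bytes (8 floats) apart in NC8HW8.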
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups 
%%ymm5, 64(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..db5d05d6f0c73f5615e22fc10bc6a2d8abb298b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst6; + __m256 dst1; + __m256 dst4; + __m256 dst7; + __m256 dst2; + __m256 dst5; + __m256 dst8; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst6 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst7 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + dst6 = _mm256_fmadd_ps(dst6, src00, weight20); + __m256 src10 = 
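// NC8HW8 packing: within an 8-deep block, element d of row r sits at
// src[r * 8 + d], so the three rows broadcast from src + 0, src + 8 and
// src + 16, and each depth step consumes col_block consecutive weight floats.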
_mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + dst7 = _mm256_fmadd_ps(dst7, src10, weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + dst8 = _mm256_fmadd_ps(dst8, src20, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + dst6 = _mm256_fmadd_ps(dst6, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + dst7 = _mm256_fmadd_ps(dst7, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + dst8 = _mm256_fmadd_ps(dst8, src21, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + dst6 = _mm256_fmadd_ps(dst6, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + dst7 = _mm256_fmadd_ps(dst7, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + dst8 = _mm256_fmadd_ps(dst8, src22, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + dst6 = _mm256_fmadd_ps(dst6, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + dst7 = _mm256_fmadd_ps(dst7, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + dst8 = _mm256_fmadd_ps(dst8, src23, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + dst6 = _mm256_fmadd_ps(dst6, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + dst7 = _mm256_fmadd_ps(dst7, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + dst8 = _mm256_fmadd_ps(dst8, src24, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, 
src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + dst6 = _mm256_fmadd_ps(dst6, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + dst7 = _mm256_fmadd_ps(dst7, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + dst8 = _mm256_fmadd_ps(dst8, src25, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + dst6 = _mm256_fmadd_ps(dst6, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + dst7 = _mm256_fmadd_ps(dst7, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + dst8 = _mm256_fmadd_ps(dst8, src26, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + dst6 = _mm256_fmadd_ps(dst6, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + dst7 = _mm256_fmadd_ps(dst7, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + dst8 = _mm256_fmadd_ps(dst8, src27, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); + _mm256_store_ps(dst + 2 * 
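// NOTE: the accumulator loads above index dst with dst_stride, but these stores
// use src_stride; dst_stride is presumably intended in both places. Likewise
// "weight += 768" advances a float pointer by what looks like a byte count; the
// eight blocks of this 3x24 tile consume 8 * 24 = 192 floats per iteration.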
src_stride + 0, dst6); + _mm256_store_ps(dst + 2 * src_stride + 8, dst7); + _mm256_store_ps(dst + 2 * src_stride + 16, dst8); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..893f352396e2e5606a9e40604e57b875edd5615f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + 
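// Flag protocol, as inferred from the prologue/epilogue of these asm kernels:
// inc_flag bit 0 selects accumulation on top of the partial sums already in dst,
// and inc_flag bit 1 marks the final depth pass and gates the activation code.
// act_flag bit 0 requests relu and bit 1 relu6 in the intrinsic variants, while
// the asm epilogues test act_flag & 0x3 before the relu block and act_flag & 0x1
// before the relu6 clamp (0x40C00000 is the IEEE-754 bit pattern of 6.0f).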
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, 
%%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..eaf7595f1bcaa5f966143ff0c413f32adf847cbf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst6; + __m256 dst9; + __m256 dst1; + __m256 dst4; + __m256 dst7; + __m256 dst10; + __m256 dst2; + __m256 dst5; + __m256 dst8; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst6 = _mm256_load_ps(bias + 16); + dst9 = _mm256_load_ps(bias + 24); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst7 = _mm256_load_ps(bias + 16); + dst10 = _mm256_load_ps(bias + 24); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst11 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst6 = _mm256_fmadd_ps(dst6, src00, weight20); + dst7 = _mm256_fmadd_ps(dst7, src10, weight20); + dst8 = _mm256_fmadd_ps(dst8, src20, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst9 = _mm256_fmadd_ps(dst9, src00, weight30); + dst10 = _mm256_fmadd_ps(dst10, src10, weight30); + dst11 = _mm256_fmadd_ps(dst11, src20, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 src11 = 
_mm256_set1_ps(*(src + 9)); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst6 = _mm256_fmadd_ps(dst6, src01, weight21); + dst7 = _mm256_fmadd_ps(dst7, src11, weight21); + dst8 = _mm256_fmadd_ps(dst8, src21, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst9 = _mm256_fmadd_ps(dst9, src01, weight31); + dst10 = _mm256_fmadd_ps(dst10, src11, weight31); + dst11 = _mm256_fmadd_ps(dst11, src21, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst6 = _mm256_fmadd_ps(dst6, src02, weight22); + dst7 = _mm256_fmadd_ps(dst7, src12, weight22); + dst8 = _mm256_fmadd_ps(dst8, src22, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst9 = _mm256_fmadd_ps(dst9, src02, weight32); + dst10 = _mm256_fmadd_ps(dst10, src12, weight32); + dst11 = _mm256_fmadd_ps(dst11, src22, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst6 = _mm256_fmadd_ps(dst6, src03, weight23); + dst7 = _mm256_fmadd_ps(dst7, src13, weight23); + dst8 = _mm256_fmadd_ps(dst8, src23, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst9 = _mm256_fmadd_ps(dst9, src03, weight33); + dst10 = _mm256_fmadd_ps(dst10, src13, weight33); + dst11 = _mm256_fmadd_ps(dst11, src23, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst6 = _mm256_fmadd_ps(dst6, src04, weight24); + dst7 = _mm256_fmadd_ps(dst7, src14, weight24); + dst8 = _mm256_fmadd_ps(dst8, src24, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst9 = 
_mm256_fmadd_ps(dst9, src04, weight34); + dst10 = _mm256_fmadd_ps(dst10, src14, weight34); + dst11 = _mm256_fmadd_ps(dst11, src24, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst6 = _mm256_fmadd_ps(dst6, src05, weight25); + dst7 = _mm256_fmadd_ps(dst7, src15, weight25); + dst8 = _mm256_fmadd_ps(dst8, src25, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst9 = _mm256_fmadd_ps(dst9, src05, weight35); + dst10 = _mm256_fmadd_ps(dst10, src15, weight35); + dst11 = _mm256_fmadd_ps(dst11, src25, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst6 = _mm256_fmadd_ps(dst6, src06, weight26); + dst7 = _mm256_fmadd_ps(dst7, src16, weight26); + dst8 = _mm256_fmadd_ps(dst8, src26, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst9 = _mm256_fmadd_ps(dst9, src06, weight36); + dst10 = _mm256_fmadd_ps(dst10, src16, weight36); + dst11 = _mm256_fmadd_ps(dst11, src26, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst6 = _mm256_fmadd_ps(dst6, src07, weight27); + dst7 = _mm256_fmadd_ps(dst7, src17, weight27); + dst8 = _mm256_fmadd_ps(dst8, src27, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst9 = _mm256_fmadd_ps(dst9, src07, weight37); + dst10 = _mm256_fmadd_ps(dst10, src17, weight37); + dst11 = _mm256_fmadd_ps(dst11, src27, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst11 = 
_mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); + _mm256_store_ps(dst + 2 * src_stride + 0, dst6); + _mm256_store_ps(dst + 2 * src_stride + 8, dst7); + _mm256_store_ps(dst + 2 * src_stride + 16, dst8); + _mm256_store_ps(dst + 3 * src_stride + 0, dst9); + _mm256_store_ps(dst + 3 * src_stride + 8, dst10); + _mm256_store_ps(dst + 3 * src_stride + 16, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..e4d2f5b86ae9795d5af32fd85df49c41801d906a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,301 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
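For orientation, here is a minimal scalar sketch of the contract these generated kernels implement, under the NC8HW8 packing and flag conventions inferred in the notes above. The helper name gemm_nc8hw8_ref is hypothetical, for illustration only, and the activation bits follow the intrinsic variants; it is not part of the patch.

#include <stddef.h>

/*
 * Scalar reference for the generated row_block x col_block kernels, assuming
 * deep is a multiple of 8 (the vector kernels iterate deep >> 3).
 */
static void gemm_nc8hw8_ref(float *dst, const float *src, const float *weight, const float *bias,
                            size_t act_flag, size_t row_block, size_t col_block, size_t deep,
                            size_t src_stride, size_t dst_stride, size_t inc_flag) {
  for (size_t c = 0; c < col_block; ++c) {    /* output channel */
    for (size_t r = 0; r < row_block; ++r) {  /* output row / pixel */
      size_t out = (c / 8) * dst_stride + r * 8 + (c % 8);
      float acc = (inc_flag & 0x1) ? dst[out] : (bias != NULL ? bias[c] : 0.0f);
      const float *s = src;
      const float *w = weight;
      for (size_t d = 0; d < deep; ++d) {
        /* element d of row r inside the current 8-deep block */
        acc += s[r * 8 + (d % 8)] * w[(d % 8) * col_block + c];
        if ((d % 8) == 7) {  /* advance to the next 8-deep block */
          s += src_stride;
          w += 8 * col_block;
        }
      }
      if (inc_flag & 0x2) {  /* final depth pass: fuse the activation */
        if (act_flag & 0x1) acc = acc > 0.0f ? acc : 0.0f;                        /* relu */
        if (act_flag & 0x2) acc = acc > 6.0f ? 6.0f : (acc > 0.0f ? acc : 0.0f);  /* relu6 */
      }
      dst[out] = acc;
    }
  }
}

Strides here are in floats, matching the intrinsic variants; the asm variants pre-scale them to bytes via src_stride << 2 and dst_stride << 2.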
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 0(%[dst_4]), %%ymm9\n" + "vmovups 32(%[dst_4]), %%ymm10\n" + "vmovups 64(%[dst_4]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 96(%[bias]), %%ymm9\n" + "vmovaps 96(%[bias]), %%ymm10\n" + "vmovaps 96(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_4] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vmovaps 0(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 32(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 64(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 96(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vmovaps 128(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 160(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 192(%[weight]), 
%%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 224(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vmovaps 256(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 288(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 320(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 352(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vmovaps 384(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 416(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 448(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 480(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vmovaps 512(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 544(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 576(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 608(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vmovaps 640(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 672(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 704(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 736(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vbroadcastss 6(%[src]), 
%%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vmovaps 768(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 800(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 832(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 864(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vmovaps 896(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 928(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 960(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 992(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 0(%[dst_4])\n" + "vmovups %%ymm10, 32(%[dst_4])\n" + "vmovups %%ymm11, 64(%[dst_4])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), 
[dst_stride] "r"(dst_stride_t), + [dst_4] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..aff74f6f09f2c2aca45d1fc8bf96d784188f1ecd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = 
_mm256_fmadd_ps(dst2, src23, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..dae2ecbb3dc604516049cbcb9d7261c57a6ff245 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + 
"and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4b34d19bea615bc94cd3c293fa31c5fc999c394b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst1; + __m256 dst5; + __m256 dst2; + __m256 dst6; + __m256 dst3; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + // 
bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu 
= _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..976b074a58292ae1b3fb3df4f5db54f12a7a2eab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,235 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
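
Note that the weight and bias accesses above use vmovaps (asm variants) or _mm256_load_ps (intrinsic variants), and the intrinsic variants also load and store dst with the aligned forms; all of these fault if the pointer is not 32-byte aligned. A minimal sketch of an aligned allocation helper, using C11 aligned_alloc (the helper name is made up; callers release the buffer with free()):

#include <stdlib.h>

/* Hypothetical allocator for packed weight/bias buffers. aligned_alloc (C11)
 * requires the size to be a multiple of the alignment, hence the round-up. */
static float *AllocAligned32(size_t num_floats) {
  size_t bytes = (num_floats * sizeof(float) + 31) & ~(size_t)31;
  return (float *)aligned_alloc(32, bytes);
}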
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + 
"vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, 
%%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8d987c6153a1a4c550e4846dbb7de79b65dea5f7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst8; + __m256 dst1; + __m256 dst5; + __m256 dst9; + __m256 dst2; + __m256 dst6; + __m256 dst10; + __m256 dst3; + __m256 dst7; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst9 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst10 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst11 = _mm256_load_ps(dst + 2 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst9 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst10 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst11 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + dst8 = _mm256_fmadd_ps(dst8, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + dst9 = _mm256_fmadd_ps(dst9, src10, weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + dst10 = _mm256_fmadd_ps(dst10, src20, weight20); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + dst11 = _mm256_fmadd_ps(dst11, src30, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + dst8 = _mm256_fmadd_ps(dst8, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + dst9 = _mm256_fmadd_ps(dst9, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 
17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + dst10 = _mm256_fmadd_ps(dst10, src21, weight21); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + dst11 = _mm256_fmadd_ps(dst11, src31, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + dst8 = _mm256_fmadd_ps(dst8, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + dst9 = _mm256_fmadd_ps(dst9, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + dst10 = _mm256_fmadd_ps(dst10, src22, weight22); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + dst11 = _mm256_fmadd_ps(dst11, src32, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + dst8 = _mm256_fmadd_ps(dst8, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + dst9 = _mm256_fmadd_ps(dst9, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + dst10 = _mm256_fmadd_ps(dst10, src23, weight23); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + dst11 = _mm256_fmadd_ps(dst11, src33, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + dst8 = _mm256_fmadd_ps(dst8, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + dst9 = _mm256_fmadd_ps(dst9, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + dst10 = _mm256_fmadd_ps(dst10, src24, weight24); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + dst11 = _mm256_fmadd_ps(dst11, src34, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + dst8 = _mm256_fmadd_ps(dst8, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + 
dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + dst9 = _mm256_fmadd_ps(dst9, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + dst10 = _mm256_fmadd_ps(dst10, src25, weight25); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + dst11 = _mm256_fmadd_ps(dst11, src35, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + dst8 = _mm256_fmadd_ps(dst8, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + dst9 = _mm256_fmadd_ps(dst9, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + dst10 = _mm256_fmadd_ps(dst10, src26, weight26); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + dst11 = _mm256_fmadd_ps(dst11, src36, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + dst8 = _mm256_fmadd_ps(dst8, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + dst9 = _mm256_fmadd_ps(dst9, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + dst10 = _mm256_fmadd_ps(dst10, src27, weight27); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + dst11 = _mm256_fmadd_ps(dst11, src37, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = 
_mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); + _mm256_store_ps(dst + 2 * src_stride + 0, dst8); + _mm256_store_ps(dst + 2 * src_stride + 8, dst9); + _mm256_store_ps(dst + 2 * src_stride + 16, dst10); + _mm256_store_ps(dst + 2 * src_stride + 24, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..3470b7f2ce701a4b5e933c16e62060a52d81f76f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
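
The kernels in this patch are specialized per (row_block, col_block) tile but share a single signature, so a caller would normally pick one through a function-pointer table. A sketch of such a dispatcher follows; only the kernel names and their signature come from this patch, while the selector itself is hypothetical and covers just the tiles visible in this section:

#include <stddef.h>

typedef void (*GemmFmaKernel)(float *dst, const float *src, const float *weight, const float *bias,
                              size_t act_flag, size_t row_block, size_t col_block, size_t deep,
                              size_t src_stride, size_t dst_stride, size_t inc_flag);

/* Kernels added by this patch (declarations only). */
void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                           size_t, size_t, size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                           size_t, size_t, size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                            size_t, size_t, size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                            size_t, size_t, size_t, size_t, size_t, size_t, size_t);

/* Hypothetical selector: maps a tile shape to the matching micro-kernel. */
static GemmFmaKernel SelectGemmFmaKernel(size_t row_block, size_t col_block) {
  if (row_block == 3 && col_block == 8) return nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32;
  if (row_block == 4 && col_block == 8) return nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32;
  if (row_block == 4 && col_block == 16) return nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32;
  if (row_block == 4 && col_block == 24) return nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32;
  return NULL;  /* unsupported tile shape */
}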
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm9\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm10\n" + "vmovups 96(%[dst], %[dst_stride], 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 64(%[bias]), %%ymm9\n" + "vmovaps 64(%[bias]), %%ymm10\n" + "vmovaps 64(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, 
%%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 
608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm10, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm11, 96(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] 
"r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c972e8a846b6d220110fc099d21ad24def80b9fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 
src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..9377253b6d59fc306229aa914882c4d4e885d9e5 --- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,171 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n"
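The first asm statement of each of these _asm.c kernels only chooses how to seed the accumulator registers; the unrolled multiply loop runs in the second statement. A minimal C sketch of that three-way seed, mirroring the intrinsic variants in this patch (the helper name is ours, not the patch's):

#include <immintrin.h>
#include <stddef.h>

/* Seed one 8-wide accumulator column: resume partial sums, start from the
 * per-channel bias, or start from zero, as in the asm just above. */
static inline __m256 seed_acc(const float *dst, const float *bias, size_t inc_flag) {
  if (inc_flag & 0x1) return _mm256_loadu_ps(dst); /* continuing a split-deep pass */
  if (bias != NULL) return _mm256_load_ps(bias);   /* fresh pass, bias present */
  return _mm256_setzero_ps();                      /* fresh pass, no bias */
}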
+ "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5426fb627ad7ca63de3818c9696658955776ce8f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst5; + __m256 dst1; + __m256 dst6; + __m256 dst2; + __m256 dst7; + __m256 dst3; + __m256 dst8; + __m256 dst4; + __m256 dst9; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst5 = _mm256_fmadd_ps(dst5, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst6 = _mm256_fmadd_ps(dst6, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst7 = _mm256_fmadd_ps(dst7, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst8 = _mm256_fmadd_ps(dst8, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst9 = _mm256_fmadd_ps(dst9, src40, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst5 = _mm256_fmadd_ps(dst5, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst6 = _mm256_fmadd_ps(dst6, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst7 = _mm256_fmadd_ps(dst7, src21, weight11); +
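Each 16-column kernel keeps two accumulator groups per row (dst0-dst4 for output channels 0-7, dst5-dst9 for channels 8-15), so every depth step multiplies one broadcast source scalar per row into a pair of 8-float weight vectors. A scalar reference of the 5x16 tile computed per unrolled iteration (our formulation of the intended math, not a line-by-line transcription of the intrinsics):

/* One iteration of the loop above: 8 depth steps over a 5x16 tile. */
static void gemm_5x16_ref_step(float acc[5][16], const float *src, const float *weight) {
  for (int k = 0; k < 8; ++k) {        /* "bock0" .. "bock7" */
    for (int r = 0; r < 5; ++r) {
      float s = src[r * 8 + k];        /* the _mm256_set1_ps broadcast */
      for (int c = 0; c < 16; ++c) {   /* c < 8: weight0k, c >= 8: weight1k */
        acc[r][c] += s * weight[k * 16 + c];
      }
    }
  }
}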
__m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst8 = _mm256_fmadd_ps(dst8, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst9 = _mm256_fmadd_ps(dst9, src41, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst5 = _mm256_fmadd_ps(dst5, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst6 = _mm256_fmadd_ps(dst6, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst7 = _mm256_fmadd_ps(dst7, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst8 = _mm256_fmadd_ps(dst8, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst9 = _mm256_fmadd_ps(dst9, src42, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst5 = _mm256_fmadd_ps(dst5, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst6 = _mm256_fmadd_ps(dst6, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst7 = _mm256_fmadd_ps(dst7, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst8 = _mm256_fmadd_ps(dst8, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst9 = _mm256_fmadd_ps(dst9, src43, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst5 = _mm256_fmadd_ps(dst5, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst6 = _mm256_fmadd_ps(dst6, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst7 = _mm256_fmadd_ps(dst7, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst8 = _mm256_fmadd_ps(dst8, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst9 = _mm256_fmadd_ps(dst9, src44, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst5 = _mm256_fmadd_ps(dst5, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst6 = _mm256_fmadd_ps(dst6, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst7 = _mm256_fmadd_ps(dst7, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst8 = _mm256_fmadd_ps(dst8, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + 
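Note the trip count: the loop runs deep >> 3 times over an 8-step unrolled body, so these generated kernels assume the packed depth is a multiple of 8 and leave padding to the packing stage. A caller-side sketch of that contract (names ours):

#include <stddef.h>

static size_t gemm_iterations(size_t deep) {
  size_t deep_padded = (deep + 7) & ~(size_t)7; /* packing pads depth to 8 */
  return deep_padded >> 3;                      /* one iteration per 8 depth steps */
}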
dst9 = _mm256_fmadd_ps(dst9, src45, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst5 = _mm256_fmadd_ps(dst5, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst6 = _mm256_fmadd_ps(dst6, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst7 = _mm256_fmadd_ps(dst7, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst8 = _mm256_fmadd_ps(dst8, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst9 = _mm256_fmadd_ps(dst9, src46, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst5 = _mm256_fmadd_ps(dst5, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst6 = _mm256_fmadd_ps(dst6, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst7 = _mm256_fmadd_ps(dst7, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst8 = _mm256_fmadd_ps(dst8, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst9 = _mm256_fmadd_ps(dst9, src47, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * dst_stride + 0, dst0); + _mm256_store_ps(dst + 0 * dst_stride + 8, dst1); + _mm256_store_ps(dst + 0 * dst_stride + 16, dst2); + _mm256_store_ps(dst + 0 * dst_stride + 24, dst3); + _mm256_store_ps(dst + 0 * dst_stride + 32, dst4); + _mm256_store_ps(dst + 1 * dst_stride + 0, dst5); + _mm256_store_ps(dst + 1 * dst_stride + 8, dst6); + _mm256_store_ps(dst + 1 * dst_stride + 16, dst7); + _mm256_store_ps(dst + 1 * dst_stride + 24, dst8); + _mm256_store_ps(dst + 1 * dst_stride + 32,
dst9); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..ba49549907f64e330939920a3827f05489d8e330 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,271 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2,
%%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + 
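Across the multiply loop of the _asm.c variants, ymm0-ymm9 hold the persistent accumulators while ymm15/ymm14 carry the current weight pair and ymm13/ymm12 the broadcast scalars; both asm statements list the same registers as clobbers so the compiler leaves them alone in between. A stripped-down sketch of the idiom (ours; it relies on the compiler emitting no vector code between the two statements):

float out[8];
/* statement 1: initialize a register kept live across statements */
asm volatile("vxorps %%ymm0, %%ymm0, %%ymm0\n" ::: "%ymm0");
/* statement 2: consume it; the shared clobber list keeps it intact */
asm volatile("vmovups %%ymm0, (%0)\n" :: "r"(out) : "%ymm0", "memory");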
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 128(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", 
"%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c3a6e1a6197c5cfa6ac970655fac724c1936e2a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,177 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = 
_mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + 
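The activation epilogue is selected by act_flag bits: bit 1 requests relu6 (which clamps to [0, 6] and therefore also performs the relu) and bit 0 requests plain relu. A compact equivalent of the two branches above (sketch; the helper name is ours):

#include <immintrin.h>
#include <stddef.h>

static inline __m256 apply_act(__m256 v, size_t act_flag) {
  if (act_flag & 0x02) v = _mm256_min_ps(v, _mm256_set1_ps(6.0f)); /* relu6 upper clamp */
  if (act_flag & 0x03) v = _mm256_max_ps(v, _mm256_setzero_ps());  /* zero clamp, both modes */
  return v;
}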
dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + } + _mm256_store_ps(dst + 0 * dst_stride + 0, dst0); + _mm256_store_ps(dst + 0 * dst_stride + 8, dst1); + _mm256_store_ps(dst + 0 * dst_stride + 16, dst2); + _mm256_store_ps(dst + 0 * dst_stride + 24, dst3); + _mm256_store_ps(dst + 0 * dst_stride + 32, dst4); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..2cc835a7d88465f271f23a5a0090a9a59b3e52d9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,193 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" +
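The asm variants address memory in bytes, which is why dst_stride and src_stride (given in floats) are shifted left by two before entering the asm blocks; equivalently (sketch, variable names ours):

size_t dst_stride_bytes = dst_stride * sizeof(float); /* == dst_stride << 2 */
size_t src_stride_bytes = src_stride * sizeof(float); /* == src_stride << 2 */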
"vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + 
"vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5e074d19f9a637e8800e350f2d5ffd4153bc76eb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,305 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst6; + __m256 dst1; + __m256 dst7; + __m256 dst2; + __m256 dst8; + __m256 dst3; + __m256 dst9; + __m256 dst4; + __m256 dst10; + __m256 dst5; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst10 = _mm256_load_ps(dst + 1 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst11 = _mm256_load_ps(dst + 1 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 0); + dst11 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst6 = _mm256_fmadd_ps(dst6, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst7 = _mm256_fmadd_ps(dst7, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst8 = _mm256_fmadd_ps(dst8, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst9 = _mm256_fmadd_ps(dst9, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst10 = _mm256_fmadd_ps(dst10, src40, weight10); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + dst11 = _mm256_fmadd_ps(dst11, src50, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst6 = _mm256_fmadd_ps(dst6, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst7 = _mm256_fmadd_ps(dst7, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst8 = _mm256_fmadd_ps(dst8, src21, weight11); +
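As the _nc8hw8 suffix suggests, these kernels assume channels packed in groups of 8 with the 8-channel group innermost, so each _mm256_load_ps above pulls 8 consecutive channels of one spatial position. An index helper under that layout assumption (ours; plane is H * W):

#include <stddef.h>

static inline size_t nc8hw8_offset(size_t c, size_t hw, size_t plane) {
  return ((c / 8) * plane + hw) * 8 + (c % 8); /* channel block, position, lane */
}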
__m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst9 = _mm256_fmadd_ps(dst9, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst10 = _mm256_fmadd_ps(dst10, src41, weight11); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + dst11 = _mm256_fmadd_ps(dst11, src51, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst6 = _mm256_fmadd_ps(dst6, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst7 = _mm256_fmadd_ps(dst7, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst8 = _mm256_fmadd_ps(dst8, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst9 = _mm256_fmadd_ps(dst9, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst10 = _mm256_fmadd_ps(dst10, src42, weight12); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + dst11 = _mm256_fmadd_ps(dst11, src52, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst6 = _mm256_fmadd_ps(dst6, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst7 = _mm256_fmadd_ps(dst7, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst8 = _mm256_fmadd_ps(dst8, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst9 = _mm256_fmadd_ps(dst9, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst10 = _mm256_fmadd_ps(dst10, src43, weight13); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + dst11 = _mm256_fmadd_ps(dst11, src53, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst6 = _mm256_fmadd_ps(dst6, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst7 = _mm256_fmadd_ps(dst7, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst8 = _mm256_fmadd_ps(dst8, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst9 = _mm256_fmadd_ps(dst9, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst10 = _mm256_fmadd_ps(dst10, src44, weight14); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + dst11 = _mm256_fmadd_ps(dst11, src54, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, 
weight05); + dst6 = _mm256_fmadd_ps(dst6, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst7 = _mm256_fmadd_ps(dst7, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst8 = _mm256_fmadd_ps(dst8, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst9 = _mm256_fmadd_ps(dst9, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + dst10 = _mm256_fmadd_ps(dst10, src45, weight15); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + dst11 = _mm256_fmadd_ps(dst11, src55, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst6 = _mm256_fmadd_ps(dst6, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst7 = _mm256_fmadd_ps(dst7, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst8 = _mm256_fmadd_ps(dst8, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst9 = _mm256_fmadd_ps(dst9, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst10 = _mm256_fmadd_ps(dst10, src46, weight16); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + dst11 = _mm256_fmadd_ps(dst11, src56, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst6 = _mm256_fmadd_ps(dst6, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst7 = _mm256_fmadd_ps(dst7, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst8 = _mm256_fmadd_ps(dst8, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst9 = _mm256_fmadd_ps(dst9, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst10 = _mm256_fmadd_ps(dst10, src47, weight17); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + dst11 = _mm256_fmadd_ps(dst11, src57, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst8 = 
_mm256_max_ps(dst8, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 1 * src_stride + 0, dst6); + _mm256_store_ps(dst + 1 * src_stride + 8, dst7); + _mm256_store_ps(dst + 1 * src_stride + 16, dst8); + _mm256_store_ps(dst + 1 * src_stride + 24, dst9); + _mm256_store_ps(dst + 1 * src_stride + 32, dst10); + _mm256_store_ps(dst + 1 * src_stride + 40, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..ba927555c4bfa59c07a1420983a214b8a1c8787c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,307 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm9\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm10\n" + "vmovups 160(%[dst], %[dst_stride], 1), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "vmovaps 32(%[bias]), %%ymm10\n" + "vmovaps 32(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps
%%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, 
%%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 96(%[dst], %[dst_stride], 1)\n" + "vmovups 
%%ymm10, 128(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm11, 160(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7efdcd8d0e0f5376860d29693b87a4396764ece7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 =
_mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + 
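+ // A scalar sketch (illustrative names, not generated output) of what one
+ // 8-step iteration of this 6x8 kernel evaluates; _mm256_fmadd_ps(a, b, c)
+ // computes a * b + c, so each "bock" d (0..7) performs, per row r and
+ // column c:
+ //   acc[r][c] = acc[r][c] * src[r * 8 + d] + weight[d * 8 + c];
+ // with the six row accumulators held in ymm registers, one fp32 scalar per
+ // row broadcast via _mm256_set1_ps, and 8 weight floats loaded per step.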
dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..67457e0e5e87371ed6aae642728e5acde94c79cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,215 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps
%%ymm5, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add $256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7",
"%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e99e39da7ebab9fa149b4ca4a33dfe50acf1f4d1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, 
src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 
src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..1923c152d04e43d4e76109bcc373a159231a11d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss
130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add $256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n"
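+ // Reading of the epilogue below (descriptive comment, not generated
+ // output): inc_flag bit 1 marks the final deep slice, so activation is
+ // only applied then; any bit set in act_flag & 0x3 triggers the ReLU
+ // vmaxps against zeroed %ymm15, and act_flag bit 0 additionally applies
+ // the 6.0f clamp -- 0x40C00000 is 6.0f in IEEE-754 single precision,
+ // broadcast to all lanes via vmovd plus vpermps with zeroed %ymm15 as the
+ // index vector.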
+ "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e11e2af8412c26e329a8d2630b41c35d71b00581 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 =
_mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + 
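/* A scalar sketch of what one unrolled depth step of this 8x8 nc8hw8 micro-kernel
+ is intended to compute (an illustration of the pattern, not code from this file;
+ dst_tile stands for the dst0..dst7 accumulators): at depth step k,
+   for (int row = 0; row < 8; ++row)
+     for (int col = 0; col < 8; ++col)
+       dst_tile[row][col] += src[row * 8 + k] * weight[k * 8 + col];
+ src then advances by src_stride floats once all eight steps of the group are done. */
+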
dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..c391edc04532ad134df4bc8c394d69d0ef54a348 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,259 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 
162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 
167(%[src]), %%ymm12\n"
+ "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
+ "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
+ "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+ "vbroadcastss 199(%[src]), %%ymm14\n"
+ "vbroadcastss 231(%[src]), %%ymm13\n"
+ "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
+ "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
+ "add $256, %[weight]\n"
+ "add %[src_stride], %[src]\n"
+ "dec %[deep]\n"
+ "jg 0b\n"
+
+ "movq %[inc_flag], %%rax\n"
+ "and $0x2, %%eax\n"
+ "je 3f\n"
+ "movq %[act_flag], %%rax\n"
+ "and $0x3, %%eax\n"
+ "je 3f\n"
+ // relu
+ "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+ "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+ "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+ "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
+ "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
+ "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
+ "vmaxps %%ymm5, %%ymm15, %%ymm5\n"
+ "vmaxps %%ymm6, %%ymm15, %%ymm6\n"
+ "vmaxps %%ymm7, %%ymm15, %%ymm7\n"
+ "and $0x1, %%eax\n"
+ "je 3f\n"
+ // relu6: 0x40C00000 is the bit pattern of 6.0f; ymm15 is all zeros after the relu step, so vpermps broadcasts lane 0 of ymm14
+ "mov $0x40C00000, %%eax\n"
+ "vmovd %%eax, %%xmm14\n"
+ "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+ "vminps %%ymm0, %%ymm14, %%ymm0\n"
+ "vminps %%ymm1, %%ymm14, %%ymm1\n"
+ "vminps %%ymm2, %%ymm14, %%ymm2\n"
+ "vminps %%ymm3, %%ymm14, %%ymm3\n"
+ "vminps %%ymm4, %%ymm14, %%ymm4\n"
+ "vminps %%ymm5, %%ymm14, %%ymm5\n"
+ "vminps %%ymm6, %%ymm14, %%ymm6\n"
+ "vminps %%ymm7, %%ymm14, %%ymm7\n"
+ "3:\n"
+ "vmovups %%ymm0, 0(%[dst])\n"
+ "vmovups %%ymm1, 32(%[dst])\n"
+ "vmovups %%ymm2, 64(%[dst])\n"
+ "vmovups %%ymm3, 96(%[dst])\n"
+ "vmovups %%ymm4, 128(%[dst])\n"
+ "vmovups %%ymm5, 160(%[dst])\n"
+ "vmovups %%ymm6, 192(%[dst])\n"
+ "vmovups %%ymm7, 224(%[dst])\n"
+ :
+ : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+ [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+ : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+ "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..3dc30cff4e321a8d6f2d5981ad462a86fba2797b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
@@ -0,0 +1,273 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src 
+ 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, 
weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..3f25c287df933bd4b2388e363880ab7d0bdc9419 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,281 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + 
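/* Register plan for the main loop of this 9x8 kernel: ymm0-ymm8 carry the nine
+ output rows (loaded from dst, bias, or zeroed in the first asm block above),
+ ymm15 holds the current 8-float weight column, and ymm12-ymm14 are scratch
+ registers for the broadcast src scalars. */
+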
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + 
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + : + : [src] "r"(src), [src_stride] 
"r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8cb78d4c1a89d2a40c08d23a41c0f8df48ce11cd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,536 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ 
dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + 
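/* Every FMA and the final stores in this kernel carry the {%k1} write-mask that
+ was loaded from the mask argument, so only the valid output columns are touched
+ and the same 16-wide kernel can serve column tails. A rough intrinsic picture of
+ one masked accumulate, as an illustration only (weight_v, src_scalar and acc are
+ hypothetical names): acc = _mm512_mask3_fmadd_ps(weight_v, src_scalar, acc, k1); */
+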
"vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), 
%%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], 
%[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7de18d347b284c2c1fcf6b73d7500899384ef92b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,784 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", +
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 4(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" 
+ "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 8(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 12(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + 
"vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 16(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 20(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 24(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 28(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 32(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 36(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 40(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), 
%%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 44(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 48(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 
52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 52(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 56(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 60(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 
%{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..90642d55d95650f53c06cc13c3ac3516143b831f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,577 @@ +/** + * Copyright 2022 Huawei Technologies Co., 
Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", + "%zmm10"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]),
%%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], 
%[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + 
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 9 
+ "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, 
%%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps 
%%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8988e02c693ab4d15eff4182fb3a9fcf5158c346 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,847 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ]
"r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + 
"vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, 
%%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + 
"vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "add $2048, %[weight]\n" + 
"add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, 
%%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1970c957247376aee28890dfabe7f478c83454ee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,617 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 0(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" +
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], 
%[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + 
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10 %{{%%k1}}\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10 %{{%%k1}}\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm11, 0(%[dst_9], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ 
inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c22649ec1a4f74e6b902133b456ef6d2489cfd72 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,911 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + 
"vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 0(%[bias]), %%zmm22\n" + "vmovups 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, 
%%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps 
%%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, 
%%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + 
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps 
%%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + 
"vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, 
%%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21 %{{%%k1}}\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21 %{{%%k1}}\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] 
"r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d1174d80059da4c34b43b4313b07131546268292 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 
%{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d4ecdf146d8d6880a03b1634824cd3893a4d957c --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", 
"%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..149bb2d88f28ea31fe726e7aede26bb07fd525ff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 14 + "vmovups 
2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4e9447052cdc8ecea57484478e6b8c5670141194 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,281 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
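(Editorial aside on the mask handling shared by every kernel in this patch: within each column tile, only the last zmm vector carries the k1 merge-mask, so a partially filled 16-float tail is neither accumulated nor stored outside its live lanes. A minimal sketch of how a caller might build the 16-bit word that "kmovw (%[mask]), %%k1" consumes; the helper name is an assumption, not part of this patch:)

```c
#include <stdint.h>

/* Hypothetical helper: enable the low `live` lanes of a 16-float zmm tail.
 * live is in [1, 16]; live == 16 yields 0xFFFF, i.e. a full vector. */
static inline uint16_t Avx512TailMask(unsigned live) {
  return (uint16_t)(live >= 16u ? 0xFFFFu : (1u << live) - 1u);
}
```

With merge-masking, lanes cleared in k1 keep the accumulator's previous contents through vfmadd231ps and are skipped by the masked vmovups stores, which is what lets fixed-width kernels such as the twelve-row one that closes above serve ragged column counts.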
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 5 
+ "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 
3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..885fad1624b3a26728d94a1050e12988393b7962 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
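(A second recurring piece, since every kernel repeats the same accumulator prologue: bit 0 of inc_flag selects whether the kernel resumes a depth-split accumulation by reloading dst; otherwise a non-null bias seeds the accumulators, and failing that they are zeroed. A scalar sketch of that decision for one 16-float accumulator, with semantics read off the asm rather than taken from a shipped helper:)

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical mirror of the asm prologue for a single zmm accumulator. */
static void InitAccumulator(float acc[16], const float *dst, const float *bias,
                            size_t inc_flag) {
  if (inc_flag & 0x1) {
    memcpy(acc, dst, 16 * sizeof(float));   /* resume: reload partial sums  */
  } else if (bias != NULL) {
    memcpy(acc, bias, 16 * sizeof(float));  /* first depth chunk, with bias */
  } else {
    memset(acc, 0, 16 * sizeof(float));     /* first depth chunk, no bias   */
  }
}
```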
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 
1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, 
%%zmm31, %%zmm4 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..12eecaf66b20babd2083a0c4911fec6822013d2a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,361 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
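(One more shared fragment worth decoding: the activation epilogue after each depth loop. It runs only when bit 1 of inc_flag marks the final depth chunk; act_flag & 0x3 requests ReLU, and act_flag & 0x1 additionally clamps at 6.0f, whose IEEE-754 bit pattern 0x40C00000 the asm materialises via vmovd plus vbroadcastss. A scalar rendering, assumed equivalent to the instruction sequence:)

```c
#include <stddef.h>

/* Hypothetical scalar equivalent of the masked vmaxps/vminps epilogue. */
static void ApplyActivation(float *acc, size_t n, size_t act_flag,
                            size_t inc_flag) {
  if ((inc_flag & 0x2) == 0 || (act_flag & 0x3) == 0) {
    return;  /* not the last depth chunk, or no activation requested */
  }
  for (size_t i = 0; i < n; ++i) {
    acc[i] = acc[i] > 0.0f ? acc[i] : 0.0f;    /* vmaxps: relu  */
    if (act_flag & 0x1) {
      acc[i] = acc[i] < 6.0f ? acc[i] : 6.0f;  /* vminps: relu6 */
    }
  }
}
```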
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 
1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + 
"sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7e4d87e60426f70b4b2bd0cc0ea157f3a11e9d94 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..be5f4408b57c03661e3f5d8f84e205a21488dc39 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,264 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + 
"vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps 
%%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + 
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..996d1af1e310c03edff4db03f1b13d5b5e63a3a1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,327 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..63d538b3eaaf718fd9aa1bbbd18c9bf2405d42e0 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,390 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + 
"vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + 
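+ // [editorial note, not in the generated source] Each "block" handles one
+ // depth element: four 16-float weight vectors are loaded into zmm31..zmm28
+ // (the 64 output columns), one input scalar per output row is broadcast
+ // (zmm27 for row 0, zmm26 for row 1), and the 2x4 vfmadd231ps grid below
+ // accumulates into zmm0..zmm7. The "0:" loop unrolls 16 such blocks per
+ // iteration; the "1:" loop retires the remaining depth one element at a time.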
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, 
%%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d18a9654bf29c24cbe7a8571105198d06537d07a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,453 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + 
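+ // [editorial note, not in the generated source] The %{{%%k1}} suffix
+ // merge-masks only the last 16-wide column (zmm4 for row 0, zmm9 for row 1):
+ // k1 is loaded from *mask with kmovw at the top of each asm statement, so a
+ // col_block that is not a multiple of 16 updates just the valid lanes and
+ // needs no scalar tail loop.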
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" 
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 
3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + 
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ddc82b0afa6ca455e986753b040e63d99b3d099e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,517 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 
1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // 
block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 8 + 
"vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 
4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add 
$64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..bb3bc480209db85fb7adce722ec1b9bca2b76427 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, 
%%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + 
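+ // Annotation (editorial): in this 3x16 variant the single 16-wide column is itself the potential + // tail, so every FMA and the final stores are k1-masked, unlike the wider tiles where only the + // last column of each row wears the mask.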
"vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..dc26494b4a57eb9ec8f05ed046395024f84b162a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,327 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + 
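+ /* Editorial sketch of what one depth step d of this 3x32 tile computes; r and c are illustrative + * names, not generator output: + * for (int r = 0; r < 3; ++r) + * for (int c = 0; c < 32; ++c) // bits of k1 gate c in [16, 32) + * dst[r][c] += src[r][d] * weight[d][c]; + * i.e. two FMAs per row, with the second (zmm1, zmm3, zmm5) masked. */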
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 
%{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps 
%%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c new file mode 
100644 index 0000000000000000000000000000000000000000..e0c9413396c3b586dbc0993967e8a786159475fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,413 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, 
%%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7dce3d6bd68c576519cbff951500602c66eb8bd5 --- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,500 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), 
%%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" 
+ "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), 
%%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], 
%[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d039b5b8365f62905d792b675a165b9d99b968a7 
--- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,586 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", 
"%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 
12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + 
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + 
// block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + 
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, 
%%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..11e7ae2507ad3ae9ed3ef30e5b13eb5fe322f7ee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,672 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps 
%%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), 
%%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 
2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), 
%%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, 
%%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps 
%%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, 
%%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", 
"%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..211edb7be31ea9e587b0827666c167f679a36a90 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,286 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), 
%%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 
%{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + 
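// depth tail: one element per iteration, broadcast from all four source rows into masked FMAs on zmm0-zmm3 +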
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..697b6c9c9cba3666a44a2b5e4843f51fb87171a3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,395 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + 
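// per depth element: two 16-lane weight columns (zmm31/zmm30) times four row broadcasts; only the upper column uses the %k1 tail mask +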
"vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), 
%%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, 
%%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, 
%%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..534ecd34701d63e1d0417ba354aa8963076bdd42 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,505 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 
192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), 
%%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), 
%%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 
2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, 
%%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3dd8b7e3a3a254f651d7f03382994e62a41e9fd4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,614 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
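The first asm block of every kernel above performs the same three-way accumulator setup: when bit 0 of inc_flag is set the kernel is resuming a partial depth slice and reloads the sums already in dst; otherwise it starts from the bias vector when one is given, or from zero via the vxorps chain. A hedged scalar-C equivalent, assuming row_cnt/col_cnt and the helper name InitAccumulators are illustrative rather than nnacl API:

```c
// Scalar sketch of the accumulator-initialization dispatch in the asm above.
#include <stddef.h>
#include <string.h>

static void InitAccumulators(float *acc, const float *dst, const float *bias,
                             size_t row_cnt, size_t col_cnt, size_t dst_stride,
                             size_t inc_flag) {
  for (size_t r = 0; r < row_cnt; ++r) {
    if (inc_flag & 0x1) {
      // not the first depth slice: resume from the partial sums already in dst
      memcpy(acc + r * col_cnt, dst + r * dst_stride, col_cnt * sizeof(float));
    } else if (bias != NULL) {
      // first slice with bias: every row starts from the same bias vector
      memcpy(acc + r * col_cnt, bias, col_cnt * sizeof(float));
    } else {
      // first slice, no bias: start from zero (the vxorps chain)
      memset(acc + r * col_cnt, 0, col_cnt * sizeof(float));
    }
  }
}
```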
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" 
+ "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], 
%[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + 
"cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 
64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..59e3ac4e51812d95cf8e43ea97f7560c13977ad1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,723 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
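The tail of every kernel applies the activation only on the final depth slice (bit 1 of inc_flag) and only when the low bits of act_flag request it: relu clamps at zero, and relu6 additionally clamps at 6.0f, whose IEEE-754 single-precision bit pattern is exactly the 0x40C00000 constant materialized through eax/xmm30. A hedged intrinsics sketch of this epilogue for one register; ApplyActEpilogue is an illustrative name, not nnacl API:

```c
// Sketch of the relu/relu6 epilogue shared by these generated kernels.
#include <stddef.h>
#include <immintrin.h>

static inline __m512 ApplyActEpilogue(__m512 v, size_t act_flag,
                                      size_t inc_flag) {
  // mirrors the "and $0x2, inc_flag / and $0x3, act_flag / je 3f" guards
  if (!(inc_flag & 0x2) || !(act_flag & 0x3)) return v;
  v = _mm512_max_ps(v, _mm512_setzero_ps());  // relu: max(v, 0)
  if (act_flag & 0x1) {
    // relu6: min(v, 6.0f); 0x40C00000 is the float bit pattern of 6.0f
    v = _mm512_min_ps(v, _mm512_set1_ps(6.0f));
  }
  return v;
}
```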
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t 
src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + 
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" 
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // 
block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps 
%%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp 
$16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 
192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6bdd51405da3ad224cb55f7b46913c6fc1c87233 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,833 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 192(%[bias]), %%zmm21\n" + "vmovups 256(%[bias]), %%zmm22\n" + "vmovups 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n"
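+ // Editorial note (hedged, not part of the generated kernel): zmm0..zmm23 form this
+ // kernel's 4x96 fp32 accumulator tile (4 rows x 6 zmm columns x 16 lanes),
+ // initialized from dst (when inc_flag bit 0 is set), from bias, or to zero; the
+ // sixth column (zmm5/zmm11/zmm17/zmm23) is the one masked with k1 in the FMA loop
+ // that follows.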
+ "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), 
%%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, 
%%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + 
"vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps 
%%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, 
%%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9306e9d857d478fed90be8c7618f79e1be6f485e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,326 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss
8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps 
%%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + 
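/* annotation sketch (assumes acc/k1 naming; not generator output): each "block" above consumes one depth element -- one 16-float weight vector (zmm31), five row broadcasts (zmm26-zmm30), and five merge-masked FMAs into the zmm0-zmm4 accumulators; in this 5x16 tail kernel the lone column vector is the masked remainder, so k1 guards every FMA and store. Roughly, per step: __m512 w = _mm512_loadu_ps(weight); __m512 s = _mm512_set1_ps(src[0]); acc = _mm512_mask3_fmadd_ps(w, s, acc, k1); written back with _mm512_mask_storeu_ps(dst, k1, acc). After 16 such blocks the loop advances weight by 16 * 64 = $1024 bytes and each src pointer by $64. */ +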
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c1aea19b14f7549afd9a7156a7a695081e8dc03a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,458 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), 
%%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, 
%%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, 
%%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 
13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add 
$4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..dd6e4627d21e43301b8322ba18a18b44b882463b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,591 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, 
%%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + 
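// only the last of the three 16-wide column vectors is masked in this 5x48 kernel: the + // per-row tail accumulators zmm2/zmm5/zmm8/zmm11/zmm14 take k1, the full-width ones do not +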
"vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 
7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + 
"vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fa03a12d24428ef3c5ad19586c933a97946b891d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,723 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 
16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + 
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], 
%[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + 
"vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, 
%%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + 
"vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps 
%%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, 
%%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 
0000000000000000000000000000000000000000..b2e60fa11a0af40d22b50a9ff4a75cdfac2f2716 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,856 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 
128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "vmovups 256(%[bias]), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + 
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, 
%%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + 
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" 
+ // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, 
%%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + 
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + 
"vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], 
%[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..963d83d5c5cea7098873446bc444bab124d4dc09 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,366 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), 
%%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" 
+ "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..db0e244d58810f8022c126148a038422ba1cfb45 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,522 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + 
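// per depth step: 2 weight zmm loads x 6 row broadcasts feed 12 FMAs +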
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 
32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + 
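// weight advances 128 B per depth step, 2048 B across the 16-block unroll +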
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + 
"vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + 
"vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..38dfe4c1b6b7623dc1c433598ecdb81f464ae3c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,677 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), 
%%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, 
%%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), 
%%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), 
%%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, 
%%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 
] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3dd0b65eeaa32582c3eb8245fb39b21bd53c98f7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,833 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps 
%%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), 
%%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, 
%%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], 
%[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + 
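// NOTE: each unrolled "block k" of this 6x64 kernel consumes one depth step: it
+ // loads four 64-byte weight vectors from byte offset 256 * k (e.g. block 12 above
+ // reads 3072..3264) and broadcasts the k-th input float, at byte offset 4 * k
+ // (48 for block 12), from each of the six input rows. +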
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" 
+ "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + 
"vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", 
"%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..67f76c0b37c574775aa1e351a77dbc2a195459b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,410 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + 
"vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + 
"vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 
%{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..16c5ed1ea0cdc6eeef01c144dd28666b67905dcd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,589 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + 
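// Net effect of this accumulation, as an illustrative scalar sketch (local +
// names only, not an nnacl API): +
//   for (size_t d = 0; d < depth; ++d) +
//     for (int r = 0; r < 7; ++r) +
//       for (int c = 0; c < 32; ++c) +
//         if (c < 16 || ((mask[0] >> (c - 16)) & 1)) +
//           acc[r][c] += weight[d * 32 + c] * src[r * src_stride + d]; +
// The first 16-lane tile of each row is always full; k1 only trims the tail. +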
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, 
%%zmm23, %%zmm13 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + 
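// Weight is packed two 16-float tiles per depth step, so block b reads at +
// byte offsets (2*b)*64 and (2*b+1)*64, e.g. 1664 and 1728 for block 13. +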
"vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + 
"vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ 
dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..928e80c72a8a11dec34e930bf50e71a23bd85c95 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,767 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" 
+ "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + 
"vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 4(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 8(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 12(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 16(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 20(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + 
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 24(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 28(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 
32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 32(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 36(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 40(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps 
%%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 44(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 48(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + 
"vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 52(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 56(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 60(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20 
%{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..644154c274658270d6ea7d81b73410a8075fcecc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,450 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
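For orientation before the remaining generated files: every mask kernel in this hunk follows the same scheme. The reduction dimension is unrolled in chunks of 16 (`block 0` through `block 15`); each block loads one 64-byte weight vector per 16-column group with `vmovups`, broadcasts a single input scalar per output row with `vbroadcastss`, and issues one `vfmadd231ps` per row/group pair. The 16-bit mask loaded into `k1` by `kmovw` is applied only to the last 16-column group, so a partial tile never reads or writes columns beyond the valid width. A scalar sketch of the accumulate path (a hypothetical reference function, not part of the patch; bias and activation handling omitted):

```c
/* Illustrative scalar reference for the generated row_block x col_block mask
 * kernels, accumulate path only (inc_flag & 0x1 set). col_block is assumed
 * to be a multiple of 16, as in all kernels of this hunk. */
#include <stddef.h>
#include <stdint.h>

static void gemm_mask_ref(float *dst, const float *src, const float *weight,
                          size_t row_block, size_t col_block, size_t depth,
                          size_t src_stride, size_t dst_stride, uint16_t mask) {
  for (size_t r = 0; r < row_block; ++r) {
    for (size_t c = 0; c < col_block; ++c) {
      /* k1 masks only the final 16-column group of the tile */
      if (c >= col_block - 16 && !(mask & (1u << (c % 16)))) {
        continue;
      }
      float acc = dst[r * dst_stride + c];
      for (size_t d = 0; d < depth; ++d) {
        /* one vbroadcastss of src plus one vfmadd231ps per step */
        acc += src[r * src_stride + d] * weight[d * col_block + c];
      }
      dst[r * dst_stride + c] = acc;
    }
  }
}
```

The weight pointer advancing by `col_block * 4` bytes per block is consistent with this layout: the weights are assumed pre-packed so that each depth step holds `col_block` consecutive floats.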
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 
%{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 
%{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..bc15930ea49cc8aaa7c71d9359a02ea20f80d309 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,652 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
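Two details of the shared epilogue are easy to miss in the raw assembly. The activation is applied only when bit 1 of `inc_flag` is set (presumably marking the final pass over the depth dimension) and the low two bits of `act_flag` are non-zero, with bit 0 of `act_flag` additionally selecting the relu6 clamp; and the constant `0x40C00000` that is moved into `%eax` and broadcast through `%xmm30` into `%zmm30` is the IEEE-754 single-precision bit pattern of `6.0f`. A minimal C restatement of this control flow, with the bit meanings read off the branches rather than from any documentation (`epilogue_ref` is illustrative, not part of the patch):

```c
/* Inferred epilogue semantics: relu via vmaxps with zero, relu6 via an
 * additional vminps with 6.0f (bit pattern 0x40C00000). Illustrative only. */
#include <stddef.h>
#include <stdint.h>
#include <string.h>

static float epilogue_ref(float acc, size_t inc_flag, size_t act_flag) {
  const uint32_t bits = 0x40C00000u;           /* mov $0x40C00000, %eax        */
  float six;
  memcpy(&six, &bits, sizeof(six));            /* vmovd + vbroadcastss -> 6.0f */
  if ((inc_flag & 0x2) && (act_flag & 0x3)) {  /* last depth slice + activation */
    acc = acc < 0.0f ? 0.0f : acc;             /* relu  (vmaxps)               */
    if (act_flag & 0x1) {
      acc = acc > six ? six : acc;             /* relu6 (vminps)               */
    }
  }
  return acc;
}
```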
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 
6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps 
%%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps 
%%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 
32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, 
%%zmm22, %%zmm15 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps 
%%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), 
%%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] 
"r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d3c08cf0a7fc2e93ae27de80c4a272cfb4b79fff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,854 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm21\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm22\n" + "vmovups 
128(%[dst_6], %[dst_stride], 1), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 0(%[bias]), %%zmm21\n" + "vmovups 64(%[bias]), %%zmm22\n" + "vmovups 128(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_6]), %%zmm27\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_6]), %%zmm27\n" + "vbroadcastss 
8(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_6]), %%zmm27\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_6]), %%zmm27\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_6]), %%zmm27\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_6]), %%zmm27\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + 
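+ // Each "block k" above consumes one fp32 depth element: three 64-byte weight
+ // vectors, eight broadcast src scalars, and 24 FMA accumulations, with %k1
+ // masking the third (tail-column) accumulator of each row.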
"vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_6]), %%zmm27\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_6]), %%zmm27\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 
36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_6]), %%zmm27\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_6]), %%zmm27\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 
44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_6]), %%zmm27\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_6]), %%zmm27\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_6]), %%zmm27\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_6]), %%zmm27\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_6]), %%zmm27\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, 
%%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20 %{{%%k1}}\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20 %{{%%k1}}\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm21, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm22, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm23, 128(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); 
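+}

Between the two generated files, a caller-side sketch may help make the mask contract concrete (illustrative only, not part of this patch: the wrapper below is hypothetical, the flag values are inferred from the asm above, and the kernel prototype is assumed to come from nnacl_c/fp32/matmul_avx512_mask_fp32.h). The u_int16_t that the kernel loads with "kmovw (%[mask]), %%k1" enables one output lane per bit in the final 16-float column block, so a partial last block needs no scalar tail loop:

#include <stddef.h>
#include <sys/types.h>                             // u_int16_t (glibc/BSD), as in the kernel signature
#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"  // assumed to declare the kernel

// Build the lane mask for `remain` valid floats (1..16) in the last 16-wide
// column block; all 16 lanes stay enabled when the block is full.
static inline u_int16_t avx512_tail_mask(size_t remain) {
  return remain >= 16 ? (u_int16_t)0xFFFF : (u_int16_t)((1u << remain) - 1u);
}

// Hypothetical single-pass call for one 8x48 output tile whose last column
// block holds only `remain_cols` valid outputs. Flag bits as read from the
// asm: inc_flag bit0 reloads dst as the accumulator (for multi-chunk depth),
// bit1 runs the epilogue (activation); act_flag 0x2 applies relu, 0x3 relu6.
static void gemm_8x48_tile(float *dst, const float *src, const float *weight,
                           const float *bias, size_t depth, size_t src_stride,
                           size_t dst_stride, size_t remain_cols) {
  u_int16_t mask = avx512_tail_mask(remain_cols);
  nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(dst, src, weight, bias,
                                               /*act_flag=*/3, /*row_block=*/8,
                                               /*col_block=*/48, depth,
                                               src_stride, dst_stride,
                                               /*inc_flag=*/2, &mask);
}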
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..76431df87c88daaf21dbc68af43ac4b19a3db85f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c
@@ -0,0 +1,490 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag, const u_int16_t *mask) {
+ const float *dst_3 = dst + 3 * dst_stride;
+ const float *dst_6 = dst + 6 * dst_stride;
+ size_t dst_stride_t = dst_stride << 2;
+ asm volatile(
+ // inc in depth
+ "movq %[inc_flag], %%rax\n"
+ "kmovw (%[mask]), %%k1\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n"
+ "vmovups 0(%[dst_3]), %%zmm3\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n"
+ "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n"
+ "vmovups 0(%[dst_6]), %%zmm6\n"
+ "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n"
+ "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 0(%[bias]), %%zmm1\n"
+ "vmovups 0(%[bias]), %%zmm2\n"
+ "vmovups 0(%[bias]), %%zmm3\n"
+ "vmovups 0(%[bias]), %%zmm4\n"
+ "vmovups 0(%[bias]), %%zmm5\n"
+ "vmovups 0(%[bias]), %%zmm6\n"
+ "vmovups 0(%[bias]), %%zmm7\n"
+ "vmovups 0(%[bias]), %%zmm8\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+ "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+ "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+ "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+ "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+ "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+ ".align 16\n"
+ "2:\n"
+ :
+ : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask),
+ [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6)
+ : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8");
+ const float *src_3 = src + 3 * src_stride;
+ const float *src_6 = src + 6 * src_stride;
+ size_t src_stride_t = src_stride << 2;
+ asm volatile(
+ "kmovw (%[mask]), %%k1\n"
+ "cmp $16, %[depth]\n"
+ "jb 1f\n"
+ 
".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps 
%%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 
28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 
1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 
%{{%%k1}}\n"
+    "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n"
+    "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n"
+    "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n"
+    "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n"
+    "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n"
+    "and $0x1, %[act_flag]\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n"
+    "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n"
+    "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n"
+    "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n"
+    "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n"
+    "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n"
+    "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n"
+    "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n"
+    ".align 16\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n"
+    "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n"
+    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n"
+    "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n"
+    "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n"
+    "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n"
+    "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n"
+    "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n"
+    "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n"
+    :
+    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth),
+      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
+      [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
+      "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
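The epilogue above shows the pattern shared by every kernel in this gemm_mask_avx512 directory: the 16-bit word behind the `mask` argument is loaded once per asm block into opmask register %k1 (`kmovw (%[mask]), %%k1`), each column-tail load, vfmadd231ps, and store is predicated on it, and the 0x40C00000 immediate in the relu6 path is simply the IEEE-754 bit pattern of 6.0f. A hedged sketch, not part of the patch, of how a caller could build that mask word for a partial block of valid output channels:

# hypothetical helper: one mask bit per fp32 lane of a zmm register (16 lanes);
# a tail of `valid` live output channels keeps only the low `valid` bits set,
# so the masked vmovups/vfmadd231ps instructions touch just those lanes
def tail_mask(valid):
    assert 0 < valid <= 16
    return (1 << valid) - 1

print(hex(tail_mask(16)))  # 0xffff -> full 16-wide column, masking is a no-op
print(hex(tail_mask(5)))   # 0x1f   -> lanes 0..4 only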
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..4b3e39b72cfda680eeac6392e9c55b33af901a5f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c
@@ -0,0 +1,715 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "kmovw (%[mask]), %%k1\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
+    "vmovups 0(%[dst_3]), %%zmm6\n"
+    "vmovups 64(%[dst_3]), %%zmm7\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 0(%[dst_6]), %%zmm12\n"
+    "vmovups 64(%[dst_6]), %%zmm13\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 64(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "vmovups 64(%[bias]), %%zmm9\n"
+    "vmovups 0(%[bias]), %%zmm10\n"
+    "vmovups 64(%[bias]), %%zmm11\n"
+    "vmovups 0(%[bias]), %%zmm12\n"
+    "vmovups 64(%[bias]), %%zmm13\n"
+    "vmovups 0(%[bias]), %%zmm14\n"
+    "vmovups 64(%[bias]), %%zmm15\n"
+    "vmovups 0(%[bias]), %%zmm16\n"
+    "vmovups 64(%[bias]), %%zmm17\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    "vxorps %%zmm15, %%zmm15, %%zmm15\n"
+    "vxorps %%zmm16, %%zmm16, %%zmm16\n"
+    "vxorps %%zmm17, %%zmm17, %%zmm17\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask),
+      [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6)
+    : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
+      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17");
+  const float *src_3 = src + 3 * src_stride;
+  const float *src_6 = src + 6 * src_stride;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "kmovw (%[mask]), %%k1\n"
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 
16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), 
%%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and 
$0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh new file mode 100644 index 0000000000000000000000000000000000000000..cd2fbf519b4bae376ca3cef489e7c8099e011ea1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh @@ -0,0 +1,88 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +CRTDIR=$( cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# generate gemm fma asm code +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c 
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
+#
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
+#
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c
+#
+## generate gemm fma intrinsics code
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c
+#
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c
+#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O 
./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c + +# generate gemm avx512 asm code +n=(96 80 64 48 32 16) +m=(4 5 6 8 12 12) +for ((index = 0; index < 6; index++)) +do + for ((row = 1; row <= ${m[index]}; row++)) + do + dst_file=$CRTDIR"/gemm_avx512/nnacl_gemm_avx512_$row""x${n[index]}_kernel_nhwc_fp32.c" + python3 $CRTDIR/generator.py -I $CRTDIR/template_file/gemm_avx512_nhwc_asm.c.in -A row_block=$row col_block=${n[index]} -O $dst_file + done +done diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..89342097d9b7d703d4cb56a75f5a7773dd5cc1f7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py @@ -0,0 +1,162 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""HPC generator"""
+
+import sys
+import os
+import io
+import stat
+import argparse
+from itertools import chain
+
+def key_value_pair(line):
+    """
+    split key and value
+    :param line:
+    :return:
+    """
+    key = None
+    value = None
+    try:
+        key, value = line.split("=", 1)
+    except ValueError:
+        print("line must be format: key=value, but now is:", line)
+        sys.exit(1)
+    try:
+        value = int(value)
+    except ValueError:
+        print("Error: you input value must be integer, but now is:", value)
+        sys.exit(1)
+    return key, value
+
+def get_indent(line):
+    """
+    get indent length
+    :param line:
+    :return:
+    """
+    index = 0
+    for i in line:
+        if i == " ":
+            index += 1
+        else:
+            break
+    return index
+
+def print_line(line):
+    """
+    Convert line to a python string
+    :param line:
+    :return:
+    """
+    global PYTHON_INDENT
+    global GENERATE_CODE_INDENT
+    if line.strip()[0] == "}" or line.strip()[0] == ")":
+        PYTHON_INDENT = -1
+    split_str = line.split("@")
+    if line.strip()[0] != "@" and len(split_str) == 1:
+        if get_indent(line) == PYTHON_INDENT or PYTHON_INDENT == -1:
+            result = ["print(", line, ", file=OUT_STREAM)"]
+            PYTHON_INDENT = -1
+            if "{" in line or "asm volatile(" in line:
+                GENERATE_CODE_INDENT = get_indent(line)
+            if line.strip().startswith("}") and "{" not in line:
+                GENERATE_CODE_INDENT -= 4
+            if len(line) == 1 and line[0] == "}":
+                # modify next fun GENERATE_CODE_INDENT
+                GENERATE_CODE_INDENT = -4
+            return "\"".join(result)
+
+    if line.strip()[0] == '@':
+        # get python indent and first GENERATE_CODE_INDENT
+        if PYTHON_INDENT == -1:
+            GENERATE_CODE_INDENT = get_indent(line) - 4
+            PYTHON_INDENT = get_indent(line)
+        result = split_str[0][PYTHON_INDENT:] + split_str[1]
+        return result
+
+    index = get_indent(split_str[0])
+    result = [split_str[0][PYTHON_INDENT:index] + "print("]
+    prefix = " " * (GENERATE_CODE_INDENT + 4) + split_str[0].lstrip()
+
+    suffix = " %("
+    for str_tmp in split_str[1:]:
+        second = str_tmp.find("}")
+        suffix += str_tmp[1:second] + ', '
+        str_tmp = str_tmp.replace(str_tmp[0:second + 1], "%d")
+        prefix += str_tmp
+    result.append(prefix)
+    result.append(suffix + "), file=OUT_STREAM)")
+    return "\"".join(result)
+
+def generate_code(template_file, exec_dict):
+    """
+    generate hpc
+    :param template_file: template file path
+    :param exec_dict: dict
+    :return: hpc
+    """
+    output_stream = io.StringIO()
+    with open(template_file, 'r') as f:
+        generate_code_lines = []
+        for line in f:
+            line = line.replace("\n", "")
+            if line.strip() and line.strip()[0] != "@":
+                line = line.replace("\"", "\\\"")
+                line = line.replace("%", "%%")
+                if "print" in line:
+                    line = line.replace("%%", "%")
+            if not line:
+                generate_code_lines.append("print(" + "\"" + line + "\"" + ", file=OUT_STREAM)")
+            else:
+                str = print_line(line)
+                if "%(" not in str:
+                    str = str.replace("%%[", "%[")
+                generate_code_lines.append(str)
+    c = compile('\n'.join(generate_code_lines), '', 'exec')
+    exec_dict["OUT_STREAM"] = output_stream
+    exec(c, exec_dict)
+    return output_stream.getvalue()
+def check_python_version():
+    if sys.version_info < (3, 6):
+        sys.stdout.write("At least python 3.6 is required, but now is " + str(sys.version_info.major) + "."
+                         + str(sys.version_info.minor) + "\n")
+        sys.exit(1)
+
+GENERATE_CODE_INDENT = -4
+PYTHON_INDENT = -1
+
+parser = argparse.ArgumentParser(description="MSLite NNACL Code Generator")
+parser.add_argument("-I", dest="Template_File", nargs=1, help="template file to generate code")
+parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=key_value_pair, action="append",
+                    help="Custom Parameters")
+parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path")
+
+if __name__ == "__main__":
+    check_python_version()
+    parameters = parser.parse_args(sys.argv[1:])
+    exec_globals = dict(chain(*parameters.defines))
+
+    generate_code_str = generate_code(parameters.Template_File[0], exec_globals)
+    if os.path.exists(parameters.Output_File[0]):
+        os.remove(parameters.Output_File[0])
+
+    saveDir = os.path.dirname(parameters.Output_File[0])
+    if not os.path.exists(saveDir):
+        os.mkdir(saveDir, 0o700)
+    with open(parameters.Output_File[0], "w", encoding='utf-8') as output_file:
+        output_file.write(generate_code_str)
+    os.chmod(parameters.Output_File[0], stat.S_IWUSR + stat.S_IRUSR)
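generator.py above is the engine behind every generated file in this patch: plain template lines are emitted verbatim, lines whose first non-blank character is `@` run as Python, and `@{expr}` splices an integer expression into the emitted text; generate_hpc.sh drives it as `python3 generator.py -I <template.c.in> -A key=value ... -O <out.c>`. A minimal smoke test of that expansion, hypothetical and not part of the patch (it assumes generator.py is importable from the working directory):

# hedged sketch: expand a one-line template with generate_code()
import generator

with open("demo.c.in", "w") as f:
    f.write("// @{row_block}x@{col_block} kernel\n")

# the dict stands in for the "-A row_block=2 col_block=16" pairs parsed in __main__
print(generator.generate_code("demo.c.in", {"row_block": 2, "col_block": 16}))
# prints: // 2x16 kernel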
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, const u_int16_t* mask) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t dst_stride_t = dst_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in depth + "movq %[inc_flag], %rax\\n" + "kmovw (%[mask]), %k1\\n" + "and $0x1, %rax\\n" + "je 0f\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovups @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)", "[mask] \"r\"(mask)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", \"%k1\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * src_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %k1\\n" + @loop_count = 16 + "cmp $@{loop_count}, %[depth]\\n" + "jb 1f\\n" + ".align 16\\n" + "0:\\n" + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - 
(row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "sub $@{loop_count}, %[depth]\\n" + "cmp $@{loop_count}, %[depth]\\n" + "jge 0b\\n" + "cmp $0, %[depth]\\n" + "je 2f\\n" + ".align 16\\n" + "1:\\n" + @loop_count = 1 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + 
@src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "dec %[depth]\\n" + "jg 1b\\n" + ".align 16\\n" + "2:\\n" + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "and $0x3, %[act_flag]\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + "vmaxps %%zmm@{row * col_split_num + col}, %%zmm31, %%zmm@{row * col_split_num + col} %{{%%k1}}\\n" + @else: + "vmaxps %%zmm@{row * col_split_num + col}, %%zmm31, %%zmm@{row * col_split_num + col}\\n" + "and $0x1, %[act_flag]\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %%xmm@{30}, %%zmm30\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + "vminps %%zmm@{row * col_split_num + col}, %%zmm30, %%zmm@{row * col_split_num + col} %{{%%k1}}\\n" + @else: + "vminps %%zmm@{row * col_split_num + col}, %%zmm30, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "3:\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}]) %{{%%k1}}\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}) %{{%%k1}}\\n" + @else: + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[depth] \"r\"(depth)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)", 
"[mask] \"r\"(mask)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in new file mode 100644 index 0000000000000000000000000000000000000000..335ed7f244937b5f10a6aed9be1a1613a7daea5d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in @@ -0,0 +1,231 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t dst_stride_t = dst_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in depth + "movq %[inc_flag], %rax\\n" + "and $0x1, %rax\\n" + "je 0f\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovups @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * src_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + 
"\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride << 2; + asm volatile( + @loop_count = 16 + "cmp $@{loop_count}, %[depth]\\n" + "jb 1f\\n" + ".align 16\\n" + "0:\\n" + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "sub $@{loop_count}, %[depth]\\n" + "cmp $@{loop_count}, %[depth]\\n" + "jge 0b\\n" + "cmp $0, %[depth]\\n" + "je 2f\\n" + ".align 16\\n" + "1:\\n" + @loop_count = 1 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], 
@{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "dec %[depth]\\n" + "jg 1b\\n" + ".align 16\\n" + "2:\\n" + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "and $0x3, %[act_flag]\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n" + "and $0x1, %[act_flag]\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %xmm30, %zmm30\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n" + ".align 16\\n" + "3:\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[depth] \"r\"(depth)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + 
);
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in
new file mode 100644
index 0000000000000000000000000000000000000000..641b1857e787440a97bf91843a781afaa307baaf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
+                                                                 const float *bias, const size_t act_flag, const size_t row_block,
+                                                                 const size_t col_block, const size_t deep, const size_t src_stride,
+                                                                 const size_t dst_stride, const size_t inc_flag) {
+  @for i in range(0, row_block):
+    @for j in range(0, col_block >> 3):
+  __m256 dst@{j * row_block + i};
+  if (inc_flag) {
+    @for i in range(0, row_block):
+      @for j in range(0, col_block >> 3):
+    dst@{j * row_block + i} = _mm256_load_ps(dst + @{j} * dst_stride + @{i * 8});
+  } else if (bias == NULL) {
+    @for i in range(0, row_block * col_block >> 3):
+    dst@{i} = _mm256_setzero_ps();
+  } else {
+    @for i in range(0, row_block):
+      @for j in range(0, col_block >> 3):
+    dst@{j * row_block + i} = _mm256_load_ps(bias + @{j * 8});
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    @for i in range(0, 8):
+    // block @{i}
+    @if col_block == 32:
+      @for row in range(0, row_block):
+    __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i}));
+      @for col in range(0, col_block >> 3):
+    __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block});
+        @for row in range(0, row_block):
+    dst@{row + col * row_block} = _mm256_fmadd_ps(src@{row}@{i}, weight@{col}@{i}, dst@{row + col * row_block});
+    @else:
+      @for col in range(0, col_block >> 3):
+    __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block});
+      @for row in range(0, row_block):
+    __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i}));
+        @for col in range(0, col_block >> 3):
+    dst@{row + col * row_block} = _mm256_fmadd_ps(src@{row}@{i}, weight@{col}@{i}, dst@{row + col * row_block});
+    src = src + src_stride;
+    weight += @{8 * col_block};
+  }
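+  // act_flag encoding used by these kernels: bit 1 selects relu6 (clamp to
+  // [0, 6]); bit 0 selects plain relu.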
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    @for i in range(0, row_block):
+      @for j in range(0, col_block >> 3):
+    dst@{i + j * row_block} = _mm256_min_ps(dst@{i + j * row_block}, relu6);
+    // relu
+    @for i in range(0, row_block):
+      @for j in range(0, col_block >> 3):
+    dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    @for i in range(0, row_block):
+      @for j in range(0, col_block >> 3):
+    dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu);
+  }
+  @for j in range(0, col_block >> 3):
+    @for i in range(0, row_block):
+  _mm256_store_ps(dst + @{j} * dst_stride + @{i * 8}, dst@{j * row_block + i});
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in
new file mode 100644
index 0000000000000000000000000000000000000000..70178cf583ceafa1cc6fb9485e5364b393d64af6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in
@@ -0,0 +1,149 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
+                                                                 const float *bias, const size_t act_flag, const size_t row_block,
+                                                                 const size_t col_block, const size_t deep, const size_t src_stride,
+                                                                 const size_t dst_stride, const size_t inc_flag) {
+  @if col_block == 32:
+  const float *dst_4 = dst + 3 * dst_stride;
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\\n"
+    "je 0f\\n"
+    @for col in range(0, min((col_block >> 3), 3)):
+      @for row in range(0, row_block):
+        @if col == 0:
+    "vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n"
+        @else:
+    "vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n"
+    @if col_block == 32:
+      @for row in range(0, row_block):
+    "vmovups @{row * 32}(%[dst_4]), %%ymm@{row + (col + 1) * row_block}\\n"
+    "jmp 2f\\n"
+    "0:\\n"
+    "cmpq $0, %[bias]\\n"
+    "je 1f\\n"
+    @for col in range(0, col_block >> 3):
+      @for row in range(0, row_block):
+    "vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n"
+    "jmp 2f\\n"
+    "1:\\n"
+    @for col in range(0, col_block >> 3):
+      @for row in range(0, row_block):
+    "vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n"
+    "2:\\n"
+    :
+    @list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
+    @if col_block == 32:
+      @list.append("[dst_4] \"r\"(dst_4)")
+    @print(" : " + ", ".join(list), file=OUT_STREAM)
+    @print(" : " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM)
+  );
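+  // Main reduction loop: each pass of label 0 consumes 8 depth elements, so
+  // it executes deep_t = deep >> 3 times.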
+  asm volatile(
+    "0:\\n"
+    @for i in range(0, 8):
+    // block @{i}
+    @if col_block == 32:
+      @for row in range(0, row_block):
+    "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - row}\\n"
+      @for col in range(0, col_block >> 3):
+    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n"
+        @for row in range(0, row_block):
+    "vfmadd231ps %%ymm@{15 - row_block}, %%ymm@{15 - row}, %%ymm@{row + col * row_block}\\n"
+    @elif col_block == 24:
+      @for col in range(0, col_block >> 3):
+    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
+      @for row in range(0, row_block):
+    "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n"
+        @for col in range(0, col_block >> 3):
+    "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{row + col * row_block}\\n"
+    @elif col_block == 16:
+      @for col in range(0, col_block >> 3):
+    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
+      @for row in range(0, row_block >> 1):
+    "vbroadcastss @{row * 64 + i}(%[src]), %%ymm@{14 - col}\\n"
+    "vbroadcastss @{row * 64 + 32 + i}(%[src]), %%ymm@{13 - col}\\n"
+        @for col in range(0, col_block >> 3):
+          @for j in range(0, 2):
+    "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{row * 2 + j + col * row_block}\\n"
+      @for row in range(row_block >> 1 << 1, row_block):
+    "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n"
+        @for col in range(0, col_block >> 3):
+    "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{row + col * row_block}\\n"
+    @else:
+      @for col in range(0, col_block >> 3):
+    "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
+      @split_num = 3
+      @for row in range(0, int(row_block / split_num)):
+        @for j in range(0, split_num):
+    "vbroadcastss @{row * 96 + j * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n"
+        @for col in range(0, col_block >> 3):
+          @for j in range(0, split_num):
+    "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{row * split_num + j + col * row_block}\\n"
+      @for row in range(int(row_block / split_num) * split_num, row_block):
+    "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}\\n"
+      @for col in range(0, col_block >> 3):
+        @for row in range(int(row_block / split_num) * split_num, row_block):
+    "vfmadd231ps %%ymm@{15 - col}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{row + col * row_block}\\n"
+    "add $@{col_block * 4 * 8}, %[weight]\\n"
+    "add %[src_stride], %[src]\\n"
+    "dec %[deep]\\n"
+    "jg 0b\\n"
+
+    "movq %[inc_flag], %rax\\n"
+    "and $0x2, %eax\\n"
+    "je 3f\\n"
+    "movq %[act_flag], %rax\\n"
+    "and $0x3, %eax\\n"
+    "je 3f\\n"
+    // relu
+    "vxorps %ymm15, %ymm15, %ymm15\\n"
+    @for col in range(0, col_block >> 3):
+      @for row in range(0, row_block):
+    "vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n"
+    "and $0x1, %eax\\n"
+    "je 3f\\n"
+    // relu6
+    "mov $0x40C00000, %eax\\n"
+    "vmovd %eax, %xmm14\\n"
+    "vpermps %ymm14, %ymm15, %ymm14\\n"
+    @for col in range(0, col_block >> 3):
+      @for row in range(0, row_block):
+    "vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n"
+    "3:\\n"
+    @for col in range(0, min((col_block >> 3), 3)):
+      @for row in range(0, row_block):
+        @if col == 0:
+    "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n"
+        @else:
+    "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n"
+    @if col_block == 32:
+      @for row in range(0, row_block):
+    "vmovups %%ymm@{row + (col + 1) * row_block}, @{row * 32}(%[dst_4])\\n"
+    :
+    @list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] 
\"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..433bd4bb295509d6603402989d8a0bb1426216c7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FILL_PARAMETER_H_ +#define NNACL_FILL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct FillParameter { + OpParameter op_parameter_; +} FillParameter; + +#endif // NNACL_FILL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..74049d6962d25489e3c3a662cea3d20733cd3c2d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FLATTEN_PARAMETER_H_ +#define NNACL_FLATTEN_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct FlattenParameter { + OpParameter op_parameter_; + int axis_; +} FlattenParameter; + +#endif // NNACL_FLATTEN_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0015c1786f4923f0dc5a05701b3bdafddacf8602 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ +#define NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/infer/common_infer.h" +static const int FormatTransposeInput = 2; +typedef struct FormatTransposeParameter { + // Primitive parameter + OpParameter op_parameter_; + FormatC src_format_; + FormatC dst_format_; +} FormatTransposeParameter; + +#endif // NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d03d2798cf148f0727a0f61001a24be5e9a4e574 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c @@ -0,0 +1,319 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/activation_fp16.h" +#include +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/fp16/exp_fp16.h" +#include "nnacl_c/errorcode.h" + +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero = vdupq_n_f16(0); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t src_value = vld1q_f16(src + offset); + float16x8_t rst_value = vmaxq_f16(src_value, zero); + vst1q_f16(dst + offset, rst_value); + } +#endif + for (; offset < ele_num; offset++) { + dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset]; + } + return NNACL_OK; +} + +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t six_data = vdupq_n_f16(6); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t relu6_data = vld1q_f16(data + offset); + relu6_data = vmaxq_f16(relu6_data, zero_data); + relu6_data = vminq_f16(relu6_data, six_data); + vst1q_f16(dst + offset, relu6_data); + } +#endif + for (; offset < ele_num; offset++) { + dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset]; + dst[offset] = dst[offset] > 6.0f ? 
6.0f : dst[offset]; + } + return NNACL_OK; +} + +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t alpha_data = vdupq_n_f16(alpha); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t src_tmp = vld1q_f16(src + offset); + float16x8_t mul_tmp = vmulq_f16(src_tmp, alpha_data); + uint16x8_t mask = vcleq_f16(src_tmp, zero_data); + vst1q_f16(dst + offset, vbslq_f16(mask, mul_tmp, src_tmp)); + } +#endif + for (; offset < ele_num; ++offset) { + dst[offset] = src[offset] > (float16_t)0.0f ? src[offset] : (src[offset] * alpha); + } + return NNACL_OK; +} + +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + int count = (ele_num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + float32x4_t tmp; + simd_exp128(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); + vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); + } +#endif + for (; i < ele_num; ++i) { + float temp; + simd_exp32(-src[i], &temp); + dst[i] = (float16_t)1.0f / ((float16_t)1.0f + temp); + } + return NNACL_OK; +} + +float16_t TanhOptFp16(float16_t src) { + if (src > 5.0f) { + return 1.0f; + } else if (src < -5.0f) { + return -1.0f; + } else { + float square = src * src; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + return a / b; + } +} + +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, + {17325.0f, 17325.0f, 17325.0f, 17325.0f}, + {135135.0f, 135135.0f, 135135.0f, 135135.0f}, + {28.0f, 28.0f, 28.0f, 28.0f}, + {3150.0f, 3150.0f, 3150.0f, 3150.0f}, + {62370.0f, 62370.0f, 62370.0f, 62370.0f}}; + float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; + float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; + int count = (ele_num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + float32x4_t input = vcvt_f32_f16(vld1_f16(src + i)); + float32x4_t square = vmulq_f32(input, input); + float32x4_t a = vmulq_f32( + vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]), + input); + float32x4_t b = vaddq_f32( + vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), + paramv[2]); + vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one))); + } +#endif + for (; i < ele_num; ++i) { + float input = src[i]; + float square = input * input; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + dst[i] = a / b; + dst[i] = MSMAX(dst[i], -1); + dst[i] = MSMIN(dst[i], 1); + } + return NNACL_OK; +} + +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 zero_data = vdupq_n_f16(0); + const MS_FLOAT16X8 three_data = vdupq_n_f16(3); + const MS_FLOAT16X8 six_data = vdupq_n_f16(6); + for (; i <= ele_num - C8NUM; i += C8NUM) { + MS_FLOAT16X8 in_data = MS_LDQ_F16(src + i); + MS_FLOAT16X8 tmp = MS_MAXQ_F16(in_data + three_data, zero_data); + tmp = MS_MINQ_F16(tmp, six_data); + MS_STQ_F16(dst + i, vmulq_f16(in_data, MS_DIVQ_F16(tmp, six_data))); + } +#endif + for (; i 
< ele_num; ++i) { + float16_t in = src[i]; + float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f); + dst[i] = in * relu6 / (float16_t)6.0f; + } + return NNACL_OK; +} + +int SwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + float32x4_t const_val = vdupq_n_f32(1.0f); + for (int num_max = ele_num - C16NUM; i <= num_max; i += C16NUM) { + float16x4x4_t ins = vld4_f16(src + i); + float32x4_t in0 = MS_CVT_F32_F16(ins.val[0]); + float32x4_t in1 = MS_CVT_F32_F16(ins.val[1]); + float32x4_t in2 = MS_CVT_F32_F16(ins.val[2]); + float32x4_t in3 = MS_CVT_F32_F16(ins.val[3]); + float32x4_t exp0 = simd_exp128_f32(vnegq_f32(in0)); + float32x4_t exp1 = simd_exp128_f32(vnegq_f32(in1)); + float32x4_t exp2 = simd_exp128_f32(vnegq_f32(in2)); + float32x4_t exp3 = simd_exp128_f32(vnegq_f32(in3)); + float32x4_t res0 = MS_DIVQ_F32(in0, vaddq_f32(const_val, exp0)); + float32x4_t res1 = MS_DIVQ_F32(in1, vaddq_f32(const_val, exp1)); + float32x4_t res2 = MS_DIVQ_F32(in2, vaddq_f32(const_val, exp2)); + float32x4_t res3 = MS_DIVQ_F32(in3, vaddq_f32(const_val, exp3)); + float16x4x4_t res = {MS_CVT_F16_F32(res0), MS_CVT_F16_F32(res1), MS_CVT_F16_F32(res2), MS_CVT_F16_F32(res3)}; + vst4_f16(dst + i, res); + } + for (int num_max = ele_num - C4NUM; i <= num_max; i += C4NUM) { + float32x4_t in = MS_CVT_F32_F16(vld1_f16(src + i)); + float16x4_t res = MS_CVT_F16_F32(MS_DIVQ_F32(in, vaddq_f32(const_val, simd_exp128_f32(vnegq_f32(in))))); + vst1_f16(dst + i, res); + } +#endif + for (; i < ele_num; ++i) { + float temp = simd_exp32_f32(-src[i]); + dst[i] = src[i] / (1.0f + temp); + } + return NNACL_OK; +} + +int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 zero_data = vdupq_n_f16(0); + const MS_FLOAT16X8 three_data = vdupq_n_f16(3); + const MS_FLOAT16X8 six_data = vdupq_n_f16(6); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + MS_FLOAT16X8 relu6_data = MS_LDQ_F16(src + offset) + three_data; + relu6_data = MS_MAXQ_F16(relu6_data, zero_data); + relu6_data = MS_MINQ_F16(relu6_data, six_data); + MS_STQ_F16(dst + offset, MS_DIVQ_F16(relu6_data, six_data)); + } +#endif + + for (; offset < ele_num; offset++) { + float16_t tmp = (src[offset] + 3.0 < 0.0) ? 0.0 : src[offset] + 3.0; + dst[offset] = ((tmp < 6.0) ? tmp : 6.0) / 6.0; + } + + return NNACL_OK; +} + +int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + if (min_val == FLT_MIN) { + for (i = 0; i < length; ++i) { + dst[i] = src[i] > max_val ? max_val : src[i]; + } + } else if (max_val == FLT_MAX) { + for (i = 0; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : src[i]; + } + } else { + for (i = 0; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? 
max_val : src[i]); + } + } + return NNACL_OK; +} + +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) +#ifdef ENABLE_NEON + for (int num_max = length - C16NUM; i <= num_max; i += C16NUM) { + float16x4x4_t ins = vld4_f16(src + i); + float32x4_t in0 = MS_CVT_F32_F16(ins.val[0]); + float32x4_t in1 = MS_CVT_F32_F16(ins.val[1]); + float32x4_t in2 = MS_CVT_F32_F16(ins.val[2]); + float32x4_t in3 = MS_CVT_F32_F16(ins.val[3]); + float32x4_t res0 = 0.5f * in0 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in0 * in0) * in0)); + float32x4_t res1 = 0.5f * in1 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in1 * in1) * in1)); + float32x4_t res2 = 0.5f * in2 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in2 * in2) * in2)); + float32x4_t res3 = 0.5f * in3 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in3 * in3) * in3)); + float16x4x4_t res = { + MS_CVT_F16_F32(res0), + MS_CVT_F16_F32(res1), + MS_CVT_F16_F32(res2), + MS_CVT_F16_F32(res3), + }; + vst4_f16(dst + i, res); + } + for (int num_max = length - C4NUM; i <= num_max; i += C4NUM) { + float32x4_t in = MS_CVT_F32_F16(vld1_f16(src + i)); + float32x4_t res = 0.5f * in * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in)); + vst1_f16(dst + i, MS_CVT_F16_F32(res)); + } +#endif + for (; i < length; i++) { + dst[i] = + 0.5f * src[i] * + (1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { +#ifdef ENABLE_NEON + int C8 = DOWN_ROUND(length, C8NUM); + for (; i < C8; i += C8NUM) { + float16x8_t in = vld1q_f16(src + i); + const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); + vst1q_f16(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f)); + } + } + return NNACL_OK; +} + +int SoftplusFp16(const float16_t *src, int length, float16_t *dst) { + int i = 0; + for (; i < length; ++i) { + single_exp_fp16(src[i], dst + i); + dst[i] = log1p(dst[i]); + } + return NNACL_OK; +} + +int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t one = MS_MOVQ_F16(1.0f); + for (; i <= length - 8; i += 8) { + float16x8_t src_tmp = MS_LDQ_F16(src + i); + float16x8_t exp_tmp = VexpFp16(src_tmp); // exp(x) + exp_tmp = MS_SUBQ_F16(exp_tmp, one); // exp(x) - 1 + float16x8_t elu_tmp = MS_MULQ_N_F16(exp_tmp, alpha); + uint16x8_t mask = vcleq_f16(src_tmp, MS_MOVQ_F16(0.0f)); + MS_STQ_F16(dst + i, vbslq_f16(mask, elu_tmp, src_tmp)); + } +#endif + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..2283cd0cab0fdf9718821a2e61e4b269119d2a4b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ACTIVATION_FP16_H_ +#define NNACL_FP16_ACTIVATION_FP16_H_ + +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/activation_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num); +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num); +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha); +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int SwishFp16(const float16_t *src, float16_t *dst, int ele_num); +int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val); +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate); +int SoftplusFp16(const float16_t *src, int length, float16_t *dst); +int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_ACTIVATION_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..2b9ffdbe56678190dc385633a01d4f44299d8fa7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c @@ -0,0 +1,273 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/arg_min_max_fp16.h" + +int ArgCompareAscFp16(const void *a, const void *b) { + float16_t a_value = ((ArgElement *)a)->data_.f16_data_; + float16_t b_value = ((ArgElement *)b)->data_.f16_data_; + if (b_value > a_value) { + return -1; + } + if (b_value < a_value) { + return 1; + } + + return 0; +} + +int ArgCompareDescFp16(const void *a, const void *b) { + float16_t b_value = ((ArgElement *)b)->data_.f16_data_; + float16_t a_value = ((ArgElement *)a)->data_.f16_data_; + if (b_value > a_value) { + return 1; + } + if (b_value < a_value) { + return -1; + } + + return 0; +} + +void ArgMaxTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count) { + bool out_value = param->out_value_; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float16_t value = -FLT_MAX; + int index = 0; + for (int k = 0; k < axis_count; ++k) { + float16_t value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } + if (out_value) { + outputfp16[output_offset + j] = value; + } else { + outputint[output_offset + j] = index; + } + if (output_value != NULL) { + output_value[output_offset + j] = value; + } + } + } +} + +void ArgMinTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count) { + bool out_value = param->out_value_; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float16_t value = FLT_MAX; + int index = 0; + for (int k = 0; k < axis_count; ++k) { + float16_t value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + if (out_value) { + outputfp16[output_offset + j] = value; + } else { + outputint[output_offset + j] = index; + } + if (output_value != NULL) { + output_value[output_offset + j] = value; + } + } + } +} + +void ArgMinMaxDim0Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[j].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[j].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[j].data_.f16_data_; + } + } + } + return; +} + +void ArgMinMaxDim1Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + 
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
+  int in_shape1 = in_shape[1];
+  float16_t *outputfp16 = (float16_t *)output;
+  int *outputint = (int *)output;
+  for (int i = 0; i < in_shape[0]; ++i) {
+    size_t in_dim0_offset = i * param->in_strides_[0];
+    size_t out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < param->in_strides_[1]; ++j) {
+      for (int k = 0; k < in_shape1; ++k) {
+        size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
+        param->arg_elements_[k].index_ = k;
+        param->arg_elements_[k].data_.f16_data_ = input[offset];
+      }
+      qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func);
+      for (int k = 0; k < param->topk_; ++k) {
+        size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
+        if (param->out_value_) {
+          outputfp16[out_offset] = param->arg_elements_[k].data_.f16_data_;
+        } else {
+          outputint[out_offset] = param->arg_elements_[k].index_;
+        }
+        if (output_value != NULL) {
+          output_value[out_offset] = param->arg_elements_[k].data_.f16_data_;
+        }
+      }
+    }
+  }
+  return;
+}
+
+void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
+  int in_shape1 = in_shape[1];
+  int in_shape2 = in_shape[2];
+  float16_t *outputfp16 = (float16_t *)output;
+  int *outputint = (int *)output;
+  for (int i = 0; i < in_shape[0]; ++i) {
+    size_t in_dim0_offset = i * param->in_strides_[0];
+    size_t out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < in_shape1; ++j) {
+      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
+      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
+      for (int k = 0; k < param->in_strides_[2]; ++k) {
+        for (int l = 0; l < in_shape2; ++l) {
+          size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
+          param->arg_elements_[l].index_ = l;
+          param->arg_elements_[l].data_.f16_data_ = input[offset];
+        }
+        qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func);
+        for (int l = 0; l < param->topk_; ++l) {
+          size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
+          if (param->out_value_) {
+            outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_;
+          } else {
+            outputint[out_offset] = param->arg_elements_[l].index_;
+          }
+          if (output_value != NULL) {
+            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
+          }
+        }
+      }
+    }
+  }
+}
+
+void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
+  int in_shape1 = in_shape[1];
+  int in_shape2 = in_shape[2];
+  int in_shape3 = in_shape[3];
+  float16_t *outputfp16 = (float16_t *)output;
+  int *outputint = (int *)output;
+  for (int i = 0; i < in_shape[0]; ++i) {
+    size_t in_dim0_offset = i * param->in_strides_[0];
+    size_t out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < in_shape1; ++j) {
+      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
+      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
+      for (int k = 0; k < in_shape2; ++k) {
+        size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
+        size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
+        for (int l = 0; l < in_shape3; ++l) {
+          size_t offset = l + in_dim2_offset;
+          param->arg_elements_[l].index_ = l;
+          param->arg_elements_[l].data_.f16_data_ = input[offset];
+        }
+        qsort(param->arg_elements_, 
in_shape3, sizeof(ArgElement), *compare_func); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[l].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[l].data_.f16_data_; + } + } + } + } + } +} + +void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param) { + if (param->topk_ == 1) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + + if (param->get_max_) { + ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); + } else { + ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); + } + return; + } + + COMPARE_FUNCTION compare_function = NULL; + if (param->get_max_) { + compare_function = ArgCompareDescFp16; + } else { + compare_function = ArgCompareAscFp16; + } + + switch (param->axis_) { + case 0: + ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 1: + ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 2: + ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 3: + ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function); + break; + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..b54971f24ade6d84bfc20f923d1caf7e8c0a2d0c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_ARG_MIN_MAX_FP16_H_ +#define NNACL_FP16_ARG_MIN_MAX_FP16_H_ + +#include <float.h> +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_ARG_MIN_MAX_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..01866fadd582e9faee5829af40f48583f3dcc69e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c @@ -0,0 +1,1314 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include <math.h> +#include "nnacl_c/common_func.h" +#include "nnacl_c/nnacl_utils.h" + +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param) { + TileDimensionsFp16(in0, in1, tile_in0, tile_in1, param); + return ElementAddFp16(tile_in0, tile_in1, out, size); +} + +void TileOneDimensionFp16(const void *input, void *output, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple) { + const float16_t *inData = (const float16_t *)input; + float16_t *outData = (float16_t *)output; + + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float16_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = 
input0[index] * input1[index]; + } + return NNACL_OK; +} + +int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] * input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] * input1[0]; + } + } + return NNACL_OK; +} + +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] * input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] * input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] * input1[0]; + output[index] = res > 0 ? 
res : 0; + } + } + return NNACL_OK; +} + +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] * input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] * input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] + input1[index]; + } + return NNACL_OK; +} + +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] + input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] + input1[0]; + } + } + 
return NNACL_OK; +} + +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } + float16x4_t zeros1 = vdup_n_f16(0.0f); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmax_f16(vout, zeros1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] + input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] + input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] + input1[0]; + output[index] = res > 0 ? 
res : 0; + } + } + return NNACL_OK; +} + +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } + float16x4_t zeros1 = vdup_n_f16(0.0); + float16x4_t bounds1 = vdup_n_f16(6.0); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmin_f16(vmax_f16(vout, zeros1), bounds1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] + input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] + input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[index]; + } + return NNACL_OK; +} + +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] - input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + 
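// first_scalar is false in this branch, so input1 supplies the single broadcast operand, splatted into vin1_opt once outside the loop. +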
vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[0]; + } + } + return NNACL_OK; +} + +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] - input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] - input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] - input1[0]; + output[index] = res > 0 ? 
res : 0; + } + } + return NNACL_OK; +} + +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] - input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] - input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] / input1[index]; + } + return NNACL_OK; +} + +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] / input1[index]; + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] / input1[0]; + } + } + return NNACL_OK; +} + +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; 
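+ // The NEON fast path below divides eight fp16 lanes per iteration without a zero-divisor test; only the scalar tail loop reports NNACL_ERRCODE_DIVISOR_ZERO.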
+#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + float16_t res = input0[index] / input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMAX(input0[0] / input1[index], 0); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index] / input1[0], 0); + } + } + return NNACL_OK; +} + +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[index] / input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = 
vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { + if (!first_scalar) { + for (int i = 0; i < element_size; ++i) { + output[i] = input0[i] - floorf(input0[i] / input1[0]) * input1[0]; + } + } else { + for (int i = 0; i < element_size; ++i) { + output[i] = input0[0] - floorf(input0[0] / input1[i]) * input1[i]; + } + } + return NNACL_OK; +} + +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[i]); + } + return NNACL_OK; +} +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { + if (!first_scalar) { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[0]); + } + } else { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[0] / input1[i]); + } + } + return NNACL_OK; +} + +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); + for (; index <= element_size - 8; index += C8NUM) { + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input0 + index)), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input1 + index)), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1_ = vld1q_f16(input1 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_opt), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[0]) & (bool)(input1[index])); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0_ 
= vld1q_f16(input0 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_opt), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(input1[0])); + } + } + return NNACL_OK; +} + +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); + for (; index <= element_size - 8; index += C8NUM) { + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input0 + index)), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input1 + index)), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1_ = vld1q_f16(input1 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_opt), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[0]) | (bool)(input1[index])); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0_ = vld1q_f16(input0 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_opt), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(input1[0])); + } + } + return NNACL_OK; +} + +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size) { + ElementSubFp16(input0, input1, output, element_size); + return ElementMulFp16(output, output, output, element_size); +} + +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, bool first_scalar) { + ElementOptSubFp16(input0, input1, output, element_size, first_scalar); + return ElementMulFp16(output, output, output, element_size); +} + +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 
8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index], input1[index]); + } + return NNACL_OK; +} + +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[0], input1[index]); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmaxq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[index], input1[index]); + } + return NNACL_OK; +} + +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[0], input1[index]); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vminq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vmvnq_u16(vceqq_f16(vin0, vin1))); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] != input1[index]; + } + return NNACL_OK; +} + +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = 
vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vmvnq_u16(vceqq_f16(vin0_opt, vin1))); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] != input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vmvnq_u16(vceqq_f16(vin0, vin1_opt))); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] != input1[0]; + } + } + return NNACL_OK; +} + +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] == input1[index]; + } + return NNACL_OK; +} + +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] == input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] == input1[0]; + } + } + return NNACL_OK; +} + +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] < input1[index]; + } + return NNACL_OK; +} + +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] < input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1_opt)); + 
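// vcltq_f16 yields an all-ones 16-bit lane mask where the compare holds; vmovn_u16 narrows each lane to one byte for the vst1_u8 store below. +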
vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] < input1[0]; + } + } + return NNACL_OK; +} + +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] <= input1[index]; + } + return NNACL_OK; +} + +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] <= input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] <= input1[0]; + } + } + return NNACL_OK; +} + +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] > input1[index]; + } + return NNACL_OK; +} + +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] > input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] > input1[0]; + } + } + return NNACL_OK; +} + +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < 
element_size; index++) { + output[index] = input0[index] >= input1[index]; + } + return NNACL_OK; +} + +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] >= input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] >= input1[0]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..0bd3ece9986a0844f5271ca88b05d3d7dd9aebd3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h @@ -0,0 +1,124 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_ARITHMETIC_FP16_H_ +#define NNACL_FP16_ARITHMETIC_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void TileOneDimensionFp16(const void *input, void *output, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param); + +int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, bool first_scalar); +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, 
uint8_t *output, int element_size, + bool first_scalar); +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param); + +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARITHMETIC_FP16_H_ diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..eb693c62b9fc6e8828cba82e93fbdde8e68fc1b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <math.h> +#include "nnacl_c/fp16/arithmetic_self_fp16.h" + +int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return NNACL_OK; +} + +int ElementCosFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return NNACL_OK; +} + +int ElementLogFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return NNACL_OK; +} + +int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return NNACL_OK; +} + +int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = 1.f / sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementSinFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return NNACL_OK; +} + +int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return NNACL_OK; +} + +int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = roundf(input[i]); + } + return NNACL_OK; +} + +int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeilFp16(const float16_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ceilf(input[i]); + } + return NNACL_OK; +} + +int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < 
element_size; ++i) { + if (input[i] == 0.0f) { + return NNACL_ERR; + } + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} + +int ElementErfFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = erff(input[i]); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..995abffd709a3da9d6663d40679313ed436a918c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ARITHMETIC_SELF_FP16_H_ +#define NNACL_FP16_ARITHMETIC_SELF_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementCosFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementLogFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSinFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementCeilFp16(const float16_t *input, float16_t *output, int number); + +int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementErfFp16(const float16_t *input, float16_t *output, int element_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARITHMETIC_SELF_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d34d98e432b0bbc9809e0e4dd6d7338a7d62034a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/batchnorm_fp16.h" +#include <math.h> +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +void BatchNormFp16(const float16_t *input, const float16_t *mean, const float16_t *variance, + const BatchNormStruct *param, int task_id, int thread_num, float16_t *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int cur_offset = completed_units * param->channel_; + + for (int i = 0; i < cur_unit; i++) { + const float16_t *unit_input = input + cur_offset; + float16_t *unit_output = output + cur_offset; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= param->channel_ - C8NUM; c += C8NUM) { + MS_FLOAT16X8 input_8 = MS_LDQ_F16(unit_input + c); + MS_FLOAT16X8 mean_8 = MS_LDQ_F16(mean + c); + MS_FLOAT16X8 variance_8 = MS_LDQ_F16(variance + c); + MS_FLOAT16X8 variance_sqrt = MS_SQRTFX8_F16(MS_ADDQ_F16(variance_8, MS_MOVQ_F16(param->epsilon_))); + MS_FLOAT16X8 output_8 = MS_DIVQ_F16(MS_SUBQ_F16(input_8, mean_8), variance_sqrt); + MS_STQ_F16(unit_output + c, output_8); + } +#endif + for (; c < param->channel_; c++) { + float16_t variance_sqrt = sqrtf(variance[c] + param->epsilon_); + unit_output[c] = (unit_input[c] - mean[c]) / variance_sqrt; + } + cur_offset += param->channel_; + } +} + +void FusedBatchNormFp16(const float16_t *input, const float16_t *scale, const float16_t *offset, const float16_t *mean, + const float16_t *variance, const BatchNormStruct *param, int task_id, int thread_num, + float16_t *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int cur_offset = completed_units * param->channel_; + + for (int i = 0; i < cur_unit; i++) { + const float16_t *unit_input = input + cur_offset; + float16_t *unit_output = output + cur_offset; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= param->channel_ - C8NUM; c += C8NUM) { + MS_FLOAT16X8 input_8 = MS_LDQ_F16(unit_input + c); + MS_FLOAT16X8 scale_8 = MS_LDQ_F16(scale + c); + MS_FLOAT16X8 offset_8 = MS_LDQ_F16(offset + c); + MS_FLOAT16X8 mean_8 = MS_LDQ_F16(mean + c); + MS_FLOAT16X8 variance_8 = MS_LDQ_F16(variance + c); + MS_FLOAT16X8 variance_sqrt = MS_SQRTFX8_F16(MS_ADDQ_F16(variance_8, MS_MOVQ_F16(param->epsilon_))); + MS_FLOAT16X8 norm_val = MS_DIVQ_F16(MS_SUBQ_F16(input_8, mean_8), variance_sqrt); + MS_FLOAT16X8 output_8 = MS_ADDQ_F16(MS_MULQ_F16(norm_val, scale_8), offset_8); + MS_STQ_F16(unit_output + c, output_8); + } +#endif + for (; c < param->channel_; c++) { + float16_t variance_sqrt = sqrtf(variance[c] + param->epsilon_); + float16_t norm_val = (unit_input[c] - mean[c]) / variance_sqrt; + unit_output[c] = norm_val * scale[c] + offset[c]; + } + cur_offset += param->channel_; + } +} + +void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var, + const BatchNormStruct *param, float16_t *save_mean, float16_t *save_var) { + const float N = (float)param->unit_; + const float VN 
= N; + const float VNUB = (N > 1.0f) ? (N - 1.0f) : 1.0f; + const float momentum = (1.0f - param->momentum_); + + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_mean[c] += input[idx]; + } + } + for (int c = 0; c < param->channel_; c++) { + run_mean[c] /= (float16_t)N; + } + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_var[c] += (float16_t)((float)(input[idx] - run_mean[c]) * (float)(input[idx] - run_mean[c])); + } + } + for (int c = 0; c < param->channel_; c++) { + float unbiased_var = ((float)run_var[c] / VNUB); + run_var[c] = (float16_t)((float)run_var[c] / VN); + save_mean[c] = (float16_t)(momentum * (float)save_mean[c] + (1.0f - momentum) * (float)run_mean[c]); + save_var[c] = (float16_t)(momentum * (float)save_var[c] + (1.0f - momentum) * unbiased_var); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..b74a083d6df6018286742d29fb152c423ab83e6f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_BATCHNORM_FP16_H_ +#define NNACL_FP16_BATCHNORM_FP16_H_ + +#include "nnacl_c/kernel/batch_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormFp16(const float16_t *input, const float16_t *mean, const float16_t *variance, + const BatchNormStruct *param, int task_id, int thread_num, float16_t *output); +void FusedBatchNormFp16(const float16_t *input, const float16_t *scale, const float16_t *offset, const float16_t *mean, + const float16_t *variance, const BatchNormStruct *param, int task_id, int thread_num, + float16_t *output); +void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var, + const BatchNormStruct *param, float16_t *save_mean, float16_t *save_var); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_BATCHNORM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..6c556a5d2f792c18a94216459ed493c2ce73022e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CAST_FP16_H_ +#define NNACL_FP16_CAST_FP16_H_ + +#include "nnacl_c/op_base.h" +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) +#include <arm_neon.h> + +#ifdef __cplusplus +extern "C" { +#endif + +inline void BoolToFloat16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Float16ToInt32(const float16_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +inline void Float16ToInt64(const float16_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +#ifdef ENABLE_ARM64 +inline void Float32ToFloat16(const float *__restrict input, float16_t *__restrict output, int number) { + int count = (number & ~(C8NUM - 1)); + int i = 0; + for (; i < count; i += C8NUM) { + float32x4_t in1 = vld1q_f32(input + i); + float16x4_t out1 = vcvt_f16_f32(in1); + float32x4_t in2 = vld1q_f32(input + i + 4); + float16x4_t out2 = vcvt_f16_f32(in2); + float16x8_t out = vcombine_f16(out1, out2); + vst1q_f16(output + i, out); + } + for (; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Float16ToFloat32(const float16_t *__restrict input, float *__restrict output, int number) { + int count = number & ~(C8NUM - 1); + int i = 0; + for (; i < count; i += C8NUM) { + float16x8_t in = vld1q_f16(input + i); + float16x4_t in1 = vget_low_f16(in); + float16x4_t in2 = vget_high_f16(in); + float32x4_t out1 = vcvt_f32_f16(in1); + vst1q_f32(output + i, out1); + float32x4_t out2 = vcvt_f32_f16(in2); + vst1q_f32(output + i + C4NUM, out2); + } + for (; i < number; ++i) { + output[i] = (float)input[i]; + } +} +#else +void Float32ToFloat16(const float *input, float16_t *output, int number); + +void Float16ToFloat32(const float16_t *input, float *output, int number); +#endif + +#ifdef __cplusplus +} +#endif +#endif +#endif // NNACL_FP16_CAST_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..1cce761d32cdd317b84ee4138084213760f0ac0e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/common_func_fp16.h" + +void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t oc_stride, size_t hw_stride, + ActType act_type, int size) { + if (size == 0) { + return; + } + for (int oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size, oc_mod = oc % size; + for (int hw = 0; hw < plane_size; hw++) { + int src_index = oc_div * size * hw_stride + hw * size + oc_mod; + int dst_index = hw * oc_stride + oc; + float16_t value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (act_type == ActType_Relu || act_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (act_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } + return; +} + +void PostConvFuncFp16C8(const float16_t *c8_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane, + size_t oc_stride, ActType act_type) { +#ifdef ENABLE_ARM64 + size_t oc8mod = oc % C8NUM; + size_t oc8div = oc - oc8mod; + size_t stride_size = oc_stride * sizeof(float16_t); + PostFuncBiasReluC8Fp16(nhwc_out, c8_out, bias, oc8div, oc8mod, plane, stride_size, act_type); +#else + PostConvFuncCommFp16(nhwc_out, c8_out, bias, oc, plane, oc_stride, plane, act_type, C8NUM); +#endif +} + +void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane, + size_t plane_stride, ActType act_type) { +#ifdef ENABLE_ARM64 + size_t oc4mod = oc % C4NUM; + size_t oc4div = oc - oc4mod; + size_t stride_size = (plane_stride - plane) * C4NUM * sizeof(float16_t); + PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); +#else + PostConvFuncCommFp16(nhwc_out, c4_out, bias, oc, plane, oc, plane_stride, act_type, C4NUM); +#endif +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..95be975c934cb213f8fee6f16b430bd7f19fc9f9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_COMMON_FUNC_FP16_H_ +#define NNACL_FP16_COMMON_FUNC_FP16_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* deconv common */ +void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t stride, ActType act_type); +void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +/* deconv winograd */ +void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t output_channel, + size_t plane_size, size_t plane_stride, ActType act_type); +void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_COMMON_FUNC_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..41719d6a4307938bcc1d590c0f40dc0378d9e2ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ +#define NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +#ifdef ENABLE_FP16 +inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} +#endif +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..7d718b8756a8d352d302a56f1d58d958291c321f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.c @@ -0,0 +1,842 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/conv_depthwise_fp16.h" +#include <string.h> +#include "nnacl_c/fp16/activation_fp16.h" + +#ifdef ENABLE_ARM82_A32 +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +#ifdef ENABLE_ARM +static void ConvDw3x3RowLeftFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + v0 = MS_MOVQ_F16((float16_t)0.0); + int ic = 0; + for (; ic < channel - 7; ic += 8) { + v1 = MS_LDQ_F16(src + ic); + v2 = MS_LDQ_F16(src + channel + ic); + v3 = MS_LDQ_F16(src + 2 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d1 = src[i + ic]; + float16_t d2 = src[i + ic + channel]; + float16_t d3 = src[i + ic + 2 * channel]; + remain_line[i] = (float16_t)0.0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = d3 - d1; + } + } +} + +static void ConvDw3x3RowMiddleFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + v2 = MS_LDQ_F16(src + 2 * channel + ic); + v3 = MS_LDQ_F16(src + 3 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; 
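+      // channel tail: each 16-byte memset below zeroes one 8-lane float16 slot (b0..b3) before the scalar remainder is filled in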
+ memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + float16_t d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = (float16_t)0.0 - d1; + } + } +} + +static void ConvDw3x3RowSingleFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F16((float16_t)0.0); + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_STQ_F16(line + lw * ic, v0); + MS_STQ_F16(line + lw * ic + 8, v1); + MS_STQ_F16(line + lw * ic + 16, b2); + memset(line + lw * ic + 24, 0, 16); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 8] = d1; + remain_line[i + 16] = (float16_t)0.0 - d1; + } + } +} + +static void ConvDw3x3InitTopFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + int c8 = UP_ROUND(channel, C8NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c8 * lw * sizeof(float16_t)); + ConvDw3x3RowLeftFp16(src, line1, lw, channel); + ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3InitRowFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeftFp16(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeftFp16(src, line1, lw, channel); + ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow 
* 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3RowFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c8 = UP_ROUND(channel, C8NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c8 * lw * sizeof(float16_t)); + ConvDw3x3RowLeftFp16(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3BottomFp16(float16_t **lines, int width, int channel) { + float16_t *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c8 = UP_ROUND(channel, C8NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c8 * C4NUM * sizeof(float16_t)); +} + +void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data, + int width, int ori_channel, bool relu, bool relu6) { + int channel = ori_channel; + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + for (; channel > 0; channel -= 8) { + MS_FLOAT16X8 bias = MS_LDQ_F16(bias_data); + bias_data += 8; + MS_FLOAT16X8 g00 = MS_LDQ_F16(weight); + MS_FLOAT16X8 g01 = MS_LDQ_F16(weight + 8); + MS_FLOAT16X8 g02 = MS_LDQ_F16(weight + 16); + MS_FLOAT16X8 g03 = MS_LDQ_F16(weight + 24); + MS_FLOAT16X8 g10 = MS_LDQ_F16(weight + 32); + MS_FLOAT16X8 g11 = MS_LDQ_F16(weight + 40); + MS_FLOAT16X8 g12 = MS_LDQ_F16(weight + 48); + MS_FLOAT16X8 g13 = MS_LDQ_F16(weight + 56); + MS_FLOAT16X8 g20 = MS_LDQ_F16(weight + 64); + MS_FLOAT16X8 g21 = MS_LDQ_F16(weight + 72); + MS_FLOAT16X8 g22 = MS_LDQ_F16(weight + 80); + MS_FLOAT16X8 g23 = MS_LDQ_F16(weight + 88); + weight += 96; + float16_t *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00); + MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01); + MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02); + MS_FLOAT16X8 acc3 = MS_MULQ_F16(MS_LDQ_F16(line0 + 24), g03); + line0 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12); + acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line1 + 24), g13); + + line1 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22); + acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line2 + 24), g23); + + line2 += 32; + MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1)); + MS_FLOAT16X8 res1 = MS_ADDQ_F16(acc1, MS_SUBQ_F16(acc3, acc2)); + res0 = MS_ADDQ_F16(res0, bias); + res1 = MS_ADDQ_F16(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0)); + res1 = MS_MAXQ_F16(res1, MS_MOVQ_F16((float16_t)0.0)); + } + if (relu6) { + res0 = MS_MINQ_F16(res0, 
MS_MOVQ_F16((float16_t)6.0)); + res1 = MS_MINQ_F16(res1, MS_MOVQ_F16((float16_t)6.0)); + } + if (channel >= 8) { + MS_STQ_F16(cur_dst, res0); + MS_STQ_F16(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + cur_dst[ori_channel + i] = res1[i]; + } + } + cur_dst += 2 * ori_channel; + } + if (ow < width) { + MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00); + MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01); + MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02); + line0 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12); + + line1 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22); + + line2 += 32; + MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1)); + res0 = MS_ADDQ_F16(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0)); + } + if (relu6) { + res0 = MS_MINQ_F16(res0, MS_MOVQ_F16((float16_t)6.0)); + } + if (channel >= 8) { + MS_STQ_F16(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + } + } + } + dst += 8; + } +} + +void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c8 = UP_ROUND(conv_param->input_channel_, C8NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float16_t *line0 = buffer; + float16_t *line1 = buffer + units * c8 * C4NUM; + float16_t *line2 = buffer + units * c8 * C4NUM * 2; + float16_t *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTopFp16(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRowFp16(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, + conv_param->input_channel_, relu, relu6); + } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3BottomFp16(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } +} + +#endif + +void ConvDwFp16(float16_t 
*output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->stride_w_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + for (int b = 0; b < conv_param->output_batch_; b++) { + const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float16_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float16_t)); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float16_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + float16_t *dst_w = dst_data + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + ReluFp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + } + if (relu6) { + Relu6Fp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + } + } + } +} + +/*conv depthwise fp16 begin*/ +void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, + bool is_relu6) { + for (int c = 0; c < C8NUM; c++) { + dst[c] = 0; + } + const float16_t *src_kh = src; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst); + dst_8 = vfmaq_f16(dst_8, src_8, 
weight_8); + vst1q_f16(dst, dst_8); + + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop + for (int c = 0; c < C8NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float16_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float16_t *src_h = src + ih * sliding->in_h_step_; + + float16_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float16_t *src_w = src_h + iw * sliding->block_channel_; + + const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; +#ifdef ENABLE_ARM64 + ConvDwFp16Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->kernel_w_ * C8NUM * sizeof(float16_t), relu, relu6); +#else + DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM, relu, relu6); +#endif + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, + int in_sh_step, int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float16_t *src_kh = src_w; + const float16_t *weight_kh = weight; + for (int c = 0; c < C8NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_w); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_w, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += 
in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add biad relu + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp16: sliding window +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DepthwiseBorderFp16(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + ConvDwFp16Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t), + sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t), + sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t), + sliding->in_kw_step_ * sizeof(float16_t), relu, relu6); +#else + DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, + sliding->in_kh_step_, sliding->in_kw_step_, relu, relu6); +#endif + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nchwc8 +} +/*conv depthwise fp16 end*/ + +/*deconv depthwise fp16 begin*/ +void DeconvDepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, + int width, int in_kh_step, int in_kw_step, int 
kernel_w_step) { + float16_t *dst_kh = dst; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); + + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop +} + +void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int top, int bottom, + int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + const float16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float16_t *dst_h = dst + oh * sliding->in_h_step_; + + const float16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float16_t *dst_w = dst_h + ow * sliding->block_channel_; + + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + float16_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->kernel_w_ * C8NUM * sizeof(float16_t)); +#else + DeconvDepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM); +#endif + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float16_t *dst_kh = dst_w; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_NEON + float16x8_t src_8 = vld1q_f16(src_w); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + 
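// advance to the next kernel row: the deconv output patch steps by in_kh_step, the packed weights by kernel_w * C8NUM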
dst_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel, + const ConvParameter *conv_param) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + int hw = conv_param->output_h_ * conv_param->output_w_; + int hw8 = hw / C8NUM * C8NUM; + float16x8_t bias_value = vld1q_f16(bias); + float16x8_t zero = vdupq_n_f16(0.0f); + float16x8_t six = vdupq_n_f16(6.0f); + + int i = 0; + for (; i < hw8; i += C8NUM) { + float16_t *dst_ptr = dst + i * block_channel; + float16x8_t dst_value0 = vld1q_f16(dst_ptr); + float16x8_t dst_value1 = vld1q_f16(dst_ptr + C1NUM * block_channel); + float16x8_t dst_value2 = vld1q_f16(dst_ptr + C2NUM * block_channel); + float16x8_t dst_value3 = vld1q_f16(dst_ptr + C3NUM * block_channel); + float16x8_t dst_value4 = vld1q_f16(dst_ptr + C4NUM * block_channel); + float16x8_t dst_value5 = vld1q_f16(dst_ptr + C5NUM * block_channel); + float16x8_t dst_value6 = vld1q_f16(dst_ptr + C6NUM * block_channel); + float16x8_t dst_value7 = vld1q_f16(dst_ptr + C7NUM * block_channel); + + dst_value0 = vaddq_f16(dst_value0, bias_value); + dst_value1 = vaddq_f16(dst_value1, bias_value); + dst_value2 = vaddq_f16(dst_value2, bias_value); + dst_value3 = vaddq_f16(dst_value3, bias_value); + dst_value4 = vaddq_f16(dst_value4, bias_value); + dst_value5 = vaddq_f16(dst_value5, bias_value); + dst_value6 = vaddq_f16(dst_value6, bias_value); + dst_value7 = vaddq_f16(dst_value7, bias_value); + if (relu) { + dst_value0 = vmaxq_f16(dst_value0, zero); + dst_value1 = vmaxq_f16(dst_value1, zero); + dst_value2 = vmaxq_f16(dst_value2, zero); + dst_value3 = vmaxq_f16(dst_value3, zero); + dst_value4 = vmaxq_f16(dst_value4, zero); + dst_value5 = vmaxq_f16(dst_value5, zero); + dst_value6 = vmaxq_f16(dst_value6, zero); + dst_value7 = vmaxq_f16(dst_value7, zero); + } + if (relu6) { + dst_value0 = vminq_f16(dst_value0, six); + dst_value1 = vminq_f16(dst_value1, six); + dst_value2 = vminq_f16(dst_value2, six); + dst_value3 = vminq_f16(dst_value3, six); + dst_value4 = vminq_f16(dst_value4, six); + dst_value5 = vminq_f16(dst_value5, six); + dst_value6 = vminq_f16(dst_value6, six); + dst_value7 = vminq_f16(dst_value7, six); + } + vst1q_f16(dst_ptr, dst_value0); + vst1q_f16(dst_ptr + C1NUM * block_channel, dst_value1); + vst1q_f16(dst_ptr + C2NUM * block_channel, dst_value2); + vst1q_f16(dst_ptr + C3NUM * block_channel, dst_value3); + vst1q_f16(dst_ptr + C4NUM * block_channel, dst_value4); + vst1q_f16(dst_ptr + C5NUM * block_channel, dst_value5); + vst1q_f16(dst_ptr + C6NUM * block_channel, dst_value6); + vst1q_f16(dst_ptr + C7NUM * block_channel, dst_value7); + } + + float16_t *dst_ptr = dst + i * block_channel; + for (; i < hw; i++, dst_ptr += block_channel) { + float16x8_t dst_value0 = vld1q_f16(dst_ptr); + dst_value0 = vaddq_f16(dst_value0, bias_value); + dst_value0 = relu ? vmaxq_f16(dst_value0, zero) : dst_value0; + dst_value0 = relu6 ? 
vminq_f16(dst_value0, six) : dst_value0; + vst1q_f16(dst_ptr, dst_value0); + } +} + +// deconv depthwise fp16: sliding window +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t), + sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t), + sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t), + sliding->in_kw_step_ * sizeof(float16_t)); +#else + DeconvDepthwiseCenterFp16(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, + sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDepthwisePostFuncFp16(dst_data, bias, sliding->block_channel_, conv_param); + } // output C8 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nchwc8 +} +/*deconv depthwise fp16 end*/ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..36d273d14de910864e9a5025c24f37001de802da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_CONV_DEPTHWISE_FP16_H_ +#define NNACL_FP16_CONV_DEPTHWISE_FP16_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, + size_t input_channel, size_t input_step); +#ifdef ENABLE_ARM64 +void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, + size_t relu6); +void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, + size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, + size_t relu, size_t relu6); +void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w); +void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +#endif + +void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int task_id); + +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#ifdef ENABLE_ARM +void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data, + int width, int ori_channel, bool relu, bool relu6); +void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_CONV_DEPTHWISE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..515186cda6b94dd48ee5355b8c1bd8e5813d4554 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c @@ -0,0 +1,334 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/conv_fp16.h"
+#include <string.h>
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/winograd_transform_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
+
+void Im2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input,
+                        int real_cal_num, int block_index) {
+  // input format : nhwc
+  int kernel_h = conv_param->kernel_h_;
+  int kernel_w = conv_param->kernel_w_;
+  int kernel_plane = kernel_h * kernel_w;
+  int stride_h = conv_param->stride_h_;
+  int stride_w = conv_param->stride_w_;
+  int pad_h = conv_param->pad_u_;
+  int pad_w = conv_param->pad_l_;
+  int dilation_h = conv_param->dilation_h_;
+  int dilation_w = conv_param->dilation_w_;
+  int in_channel = conv_param->input_channel_;
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+
+  for (int i = 0; i < real_cal_num; i++) {
+    int block_start = block_index + i;
+    int input_h = block_start / out_w * stride_h - pad_h;
+    int input_w = block_start % out_w * stride_w - pad_w;
+    int input_stride = (input_h * in_w + input_w) * in_channel;
+    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
+    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
+    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
+    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
+    if (dilation_h == 1 && dilation_w == 1) {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * in_w * in_channel + input_stride;
+        int input_x_stride = input_y_stride + kw_s * in_channel;
+        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
+        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
+               (kw_e - kw_s) * in_channel * sizeof(float16_t));
+      }  // kernel_h loop
+    } else {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
+        for (int n = kw_s; n < kw_e; n++) {
+          int input_x_stride = input_y_stride + n * dilation_w * in_channel;
+          int input_plane_offset = (j * kernel_w + n) * in_channel + i * in_channel * kernel_plane;
+          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float16_t));
+        }  // kernel_w loop
+      }    // kernel_h loop
+    }
+  }  // tile num loop
+}
+
+// fp16 convolution common (im2col+gemm)
+void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
+              const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
+              const ConvParameter *conv_param) {
+#ifdef ENABLE_ARM64
+  const int tile_n = 16;
+#else
+  const int tile_n = 12;
+#endif
+  NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
+  NNACL_CHECK_ZERO_RETURN(tile_n);
+  int output_hw = conv_param->output_h_ * conv_param->output_w_;
+  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_);
+  int start_block = block_per_thread * task_id;
+  int start_hw = start_block * tile_n;
+  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
+  if (start_hw >= end_hw) {
+    return;
+  }
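+  // Worked example of the partitioning above (illustrative values, not from the source):
+  // with output_hw = 100, tile_n = 16 and thread_num_ = 4, there are UP_DIV(100, 16) = 7
+  // tiles and block_per_thread = UP_DIV(7, 4) = 2, so task 0 covers output rows [0, 32),
+  // task 1 rows [32, 64), task 2 rows [64, 96) and task 3 the ragged tail [96, 100).
+  int out_stride =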
conv_param->output_channel_ * tile_n; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * tile_n; + col_major_input += task_id * deep * tile_n; + size_t input_size = deep * tile_n * sizeof(float16_t); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_; + for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, tile_n); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#endif + MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, + real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc); + } + } +} + +void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->op_parameter_.thread_num_); + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int input_block = UP_DIV(output_hw, tile_n); + int block_per_thread = UP_DIV(input_block, conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int end_block = MSMIN(start_block + block_per_thread, input_block); + if (start_block >= end_block) { + return; + } + int weight_block = UP_DIV(conv_param->output_channel_, C8NUM); + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += deep * tile_n * task_id; + col_major_input += deep * tile_n * task_id; + size_t input_size = deep * tile_n * sizeof(float16_t); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + for (int i = start_block; i < end_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : output_hw - i * tile_n; + memset(packed_input, 0, input_size); + Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_in_row, i * tile_n); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#endif + const float16_t *cur_weight = packed_weight; + const float16_t *cur_bias = bias_data; + for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * deep, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? 
C8NUM : conv_param->output_channel_ - j * C8NUM; + int out_offset = j * output_hw * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(col_major_input, cur_weight, output_data + out_offset, cur_bias, conv_param->act_type_, deep, + real_in_row, real_weight_row, real_weight_row, OutType_Nhwc); + } + } + } +} + +void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(tile_n); + NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_); + int input_block = UP_DIV(param->row_, tile_n); + int weight_block = UP_DIV(param->col_, C8NUM); + + int block_per_thread = UP_DIV(input_block, param->op_parameter_.thread_num_); + int input_start_block = block_per_thread * task_id; + int input_end_block = MSMIN(input_start_block + block_per_thread, input_block); + if (input_start_block >= input_end_block) { + return; + } + input += input_start_block * tile_n * param->deep_; + pack_input += input_start_block * tile_n * param->deep_; + + int cur_row_cnt = MSMIN(block_per_thread * tile_n, param->row_ - input_start_block * tile_n); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_); +#else + RowMajor2Col12MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_); +#endif + for (int i = input_start_block; i < input_end_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n; + const float16_t *cur_weight = weight; + const float16_t *cur_bias = bias; + for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM; + int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row, + real_weight_row, real_weight_row, OutType_Nhwc); + } + pack_input += real_in_row * param->deep_; + } +} + +void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(tile_n); + NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_); + int input_block = UP_DIV(param->row_, tile_n); + int weight_block = UP_DIV(param->col_, C8NUM); + + int block_per_thread = UP_DIV(weight_block, param->op_parameter_.thread_num_); + int weight_start_block = block_per_thread * task_id; + int weight_end_block = MSMIN(weight_start_block + block_per_thread, weight_block); + if (weight_start_block >= weight_end_block) { + return; + } + for (int i = 0; i < input_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n; + const float16_t *cur_weight = weight + weight_start_block * C8NUM * param->deep_; + const float16_t *cur_bias = bias + weight_start_block * C8NUM; + for (int j = weight_start_block; j < weight_end_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? 
C8NUM : param->col_ - j * C8NUM;
+      int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row;
+      MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row,
+                 real_weight_row, real_weight_row, OutType_Nhwc);
+    }
+    pack_input += real_in_row * param->deep_;
+  }
+}
+
+// fp16 convolution winograd
+void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
+                      float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
+                      const ConvParameter *conv_param, TransFp16FuncList trans_func) {
+#ifdef ENABLE_ARM64
+  const int tile_num = 16;
+#else
+  const int tile_num = 12;
+#endif
+  NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_);
+  NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
+  int in_channel = conv_param->input_channel_;
+  int input_unit = conv_param->input_unit_;
+  int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
+  int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
+  int output_count = out_w_block * out_h_block;
+  int per_thread_num = UP_DIV(output_count, conv_param->thread_num_);
+  int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num;
+  NNACL_CHECK_ZERO_RETURN(real_tile);
+  int output_tile_count = UP_DIV(output_count, real_tile);
+  int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
+  int input_unit_square = input_unit * input_unit;
+
+  float16_t *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel;
+  float16_t *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM;
+  float16_t *tmp_data = buffer_list[2] + task_id * input_unit_square * C8NUM;
+  float16_t *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
+  // step 1 : filter transform (pre-processed offline)
+  // step 2 : input transform (online)
+  for (int b = 0; b < conv_param->input_batch_; b++) {
+    int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
+    int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
+    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
+      int out_tile_index = thread_id * real_tile;
+      int cal_num = output_count - thread_id * real_tile;
+      cal_num = cal_num > real_tile ? real_tile : cal_num;
+      if (cal_num <= 0) {
+        return;
+      }
+
+#ifdef ENABLE_ARM64
+      // Fused input-transform path, valid only for arm64 where the tile number is 16; for arm32
+      // the tile number is 12 and the pack function (InputTransform4x4Pack12Fp16) would need to be rewritten.
+      bool fused_pack =
+        (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL);
+      if (fused_pack) {
+        float16_t *opt_trans_input =
+          buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, C8NUM);
+        WinogradInputTransformOptStepFp16(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num,
+                                          out_tile_index, out_w_block, conv_param, trans_func.in_step_func_);
+
+        for (int w_index = 0; w_index < input_unit; w_index++) {
+          float16_t *src_w = opt_trans_input + w_index * input_unit * tile_num * C8NUM;
+          for (int c = 0; c < UP_DIV(in_channel, C8NUM); c++) {
+            int real_c = in_channel - c * C8NUM;
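+            // Channel tail handling (illustrative): in_channel = 20 splits into C8NUM
+            // blocks of 8, 8 and 4, so the last block packs only real_c = 4 valid lanes.
+            real_c = real_c > C8NUM ?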
C8NUM : real_c; + float16_t *src_c = src_w + c * input_unit_square * tile_num * C8NUM; + float16_t *dst_c = trans_input + c * tile_num * C8NUM; + trans_func.in_pack_func_(src_c, dst_c, C8NUM, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float16_t *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float16_t *gemm_weight = trans_weight + point_index * in_channel * oc8 * C8NUM; + MatMulFp16(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransformFp16(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float16_t *src_ptr = trans_input; + float16_t *dst_ptr = gemm_out; + float16_t *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#else + RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#endif + MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, + cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); + } else { + WinogradOutputNC8HW8TransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..6d296b669e8bc0c00f6fafe65a6a9f3d66010b0a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_FP16_CONV_FP16_H_
+#define NNACL_FP16_CONV_FP16_H_
+
+#include <arm_neon.h>
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/fp16/winograd_utils_fp16.h"
+#include "nnacl_c/fp16/winograd_transform_fp16.h"
+
+typedef float16_t *TmpBufferAddressFp16;
+typedef float16_t *MatricesFp16;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Im2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input,
+                        int real_cal_num, int block_index);
+
+// fp16 convolution common (im2col+gemm)
+void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
+              const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
+              const ConvParameter *conv_param);
+
+void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
+                       const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
+                       const ConvParameter *conv_param);
+
+void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight,
+                                            const float16_t *bias, float16_t *output, int task_id,
+                                            const MatMulParameter *param);
+
+void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight,
+                                             const float16_t *bias, float16_t *output, int task_id,
+                                             const MatMulParameter *param);
+
+// fp16 convolution winograd
+void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
+                      float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
+                      const ConvParameter *conv_param, TransFp16FuncList trans_func);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_CONV_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..31e1f08e92b00d45d622d28141f39e3a1c1c802c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c
@@ -0,0 +1,155 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16/crop_fp16.h"
+
+#include <string.h>
+
+#include "nnacl_c/crop_parameter.h"
+
+void Fp16Crop1D(const float16_t *input, float16_t *output, int *out_shape, int64_t *in_offset, int task_id,
+                int thread_count) {
+  const int out_batch = out_shape[0];
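+  // Thread striping (illustrative numbers): with out_batch = 10 and thread_count = 4,
+  // task_id_stride = UP_DIV(10, 4) = 3, so tasks 0..3 copy 3, 3, 3 and 1 elements.
+  int64_t task_id_stride = thread_count > 1 ?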
UP_DIV(out_batch, thread_count) : out_batch; + if (task_id_stride <= 0) { + return; + } + int n = task_id * task_id_stride; + if (n >= out_batch) { + return; + } + const float16_t *in_ptr = input + n + in_offset[0]; + float16_t *out_ptr = output + n; + int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride); + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_dist_stride); +} + +void Fp16Crop2D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + for (int n = 0; n < out_batch; n++) { + int h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const float16_t *in_ptr = input + (n + in_offset[0]) * in_height + h + in_offset[1]; + float16_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_dist_stride); + } +} + +void Fp16Crop3D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int in_width = in_shape[2]; + + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + const int out_width = out_shape[2]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const float16_t *in_ptr = + input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + in_offset[2]; + float16_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_width); + } + } +} + +void Fp16Crop4D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int in_width = in_shape[2]; + const int in_channel = in_shape[3]; + + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + const int out_width = out_shape[2]; + const int out_channel = out_shape[3]; + + int64_t task_id_stride = thread_count > 1 ? 
UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const float16_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + + (w + in_offset[2]) * in_stride_w + in_offset[3]; + float16_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_channel); + } + } + } +} + +void Fp16Crop(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_num) { + switch (input_dim) { + case 1: + Fp16Crop1D(input, output, out_shape, in_offset, task_id, thread_num); + break; + case 2: + Fp16Crop2D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + case 3: + Fp16Crop3D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + case 4: + Fp16Crop4D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + default: + break; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d039dd4eb37faaf83cf03de78529ed3bcf93e58f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_CROP_FP16_H_ +#define NNACL_FP16_CROP_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +void Fp16Crop(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_num); + +#endif // NNACL_FP16_CROP_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..5d20218e445db4d9baa5a2c58e800e6b22bb10c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c @@ -0,0 +1,70 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/custom_gru_fp16.h"
+#include "nnacl_c/fp16/activation_fp16.h"
+#include "nnacl_c/fp16/arithmetic_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
+
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+                   const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+                   const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) {
+  int num_step = gru_param->num_step;
+  int batch_size = gru_param->batch_size;
+  int input_size = gru_param->input_size;
+  int hidden_size = gru_param->hidden_size;
+  int output_size = batch_size * hidden_size;
+  int double_output_size = output_size * C2NUM;
+  int col_align = UP_ROUND(hidden_size, C8NUM);
+  int weight_in_offset = col_align * input_size;
+  int weight_hidden_offset = col_align * hidden_size;
+  float16_t *input_gate = buffer[1];
+  float16_t *hidden_gate = buffer[C3NUM];
+  for (int i = 0; i < num_step; ++i) {
+    if (batch_size != 1) {
+      RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size, false);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+                           bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+                           OutType_Nhwc);
+      }
+      RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size, false);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                           bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+                           OutType_Nhwc);
+      }
+    } else {
+      for (int j = 0; j < C3NUM; ++j) {
+        VecMatmulFp16(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+                      bias_input + j * col_align, ActType_No, input_size, hidden_size);
+        VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                      bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size);
+      }
+    }
+    ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size);
+    SigmoidFp16(input_gate, input_gate, double_output_size);
+    ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+    ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size);
+    TanhFp16(input_gate, input_gate, output_size);
+    ElementSubFp16(init_h, input_gate, hidden_gate, output_size);
+    ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+    ElementAddFp16(input_gate, hidden_gate, output, output_size);
+    init_h = output;
+    output += output_size;
+  }
+}
+#endif
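The fused elementwise tail above is a standard GRU cell update: the first two gate blocks get a sigmoid, the candidate block a tanh, and the new hidden state blends the candidate with the previous state. A scalar fp32 sketch of the same update for reference (illustrative only; the names and the reset/update block order are assumptions, not part of the library's API):

#include <math.h>

static float sigmoid_ref(float x) { return 1.0f / (1.0f + expf(-x)); }

// x_gates / h_gates each hold three size-`size` blocks: [gate0 | gate1 | candidate].
// h_new = (1 - z) * n + z * h_prev, with n = tanh(x_n + r * h_n).
void GruGateRef(float *h_new, const float *x_gates, const float *h_gates, const float *h_prev, int size) {
  for (int k = 0; k < size; ++k) {
    float r = sigmoid_ref(x_gates[k] + h_gates[k]);                      // reset gate
    float z = sigmoid_ref(x_gates[size + k] + h_gates[size + k]);        // update gate
    float n = tanhf(x_gates[2 * size + k] + r * h_gates[2 * size + k]);  // candidate state
    h_new[k] = (1.0f - z) * n + z * h_prev[k];
  }
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..c6f95a8dd0fd45dc8123a0ec4250a0f92f2aa3cc
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei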
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CUSTOM_GRU_FP16_H_ +#define NNACL_FP16_CUSTOM_GRU_FP16_H_ +#ifdef ENABLE_ARM64 +#include "nnacl_c/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // NNACL_FP16_CUSTOM_GRU_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..6525441a99fb7f7ef10842cbc56e3ef4c567f02b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp16/deconv_fp16.h"
+#include <float.h>
+
+void DeConvPostAddC8WithStride(const float16_t *source, float16_t *dest, size_t srcStride, size_t dststride,
+                               size_t count) {
+  if (count == 0) {
+    return;
+  }
+
+  const float16_t *src_ptr = source;
+  float16_t *dst_ptr = dest;
+  float16x8_t src1 = vld1q_f16(src_ptr);
+  float16x8_t dst1 = vld1q_f16(dst_ptr);
+  float16x8_t src2;
+  float16x8_t dst2;
+  size_t i = 1;
+  while (i < count - 1) {
+    dst1 = vaddq_f16(dst1, src1);
+    vst1q_f16(dst_ptr, dst1);
+
+    src2 = vld1q_f16(src_ptr + srcStride);
+    dst2 = vld1q_f16(dst_ptr + dststride);
+    dst2 = vaddq_f16(dst2, src2);
+    vst1q_f16(dst_ptr + dststride, dst2);
+    i = i + 2;
+    src1 = vld1q_f16(src_ptr + srcStride + srcStride);
+    dst1 = vld1q_f16(dst_ptr + dststride + dststride);
+
+    src_ptr = src_ptr + srcStride + srcStride;
+    dst_ptr = dst_ptr + dststride + dststride;
+  }
+  dst1 = vaddq_f16(dst1, src1);
+  vst1q_f16(dst_ptr, dst1);
+  if (i < count) {
+    src2 = vld1q_f16(src_ptr + srcStride);
+    dst2 = vld1q_f16(dst_ptr + dststride);
+    dst2 = vaddq_f16(dst2, src2);
+    vst1q_f16(dst_ptr + dststride, dst2);
+  }
+}
+
+int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
+                   const ConvParameter *conv_param) {
+  float16x8_t min_v = vdupq_n_f16(-FLT_MAX);
+  float16x8_t max_v = vdupq_n_f16(FLT_MAX);
+  if (conv_param->act_type_ == ActType_Relu) {
+    min_v = vdupq_n_f16(0.f);
+  }
+  if (conv_param->act_type_ == ActType_Relu6) {
+    min_v = vdupq_n_f16(0.f);
+    max_v = vdupq_n_f16(6.f);
+  }
+
+  /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
+  size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
+  size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
+  size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
+  int oc8 = UP_ROUND(output_channel, C8NUM);
+  int in_plane16 = UP_ROUND(input_plane, 16);
+  int src_iw_stride = C8NUM;
+  int src_ih_stride = conv_param->input_w_ * C8NUM;
+  int src_kw_stride = in_plane16 * C8NUM;
+  int src_kh_stride = in_plane16 * conv_param->kernel_w_ * C8NUM;
+  int dst_oh_stride = conv_param->output_w_ * C8NUM;
+  int dst_ow_stride = C8NUM;
+  int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
+  int dst_kw_stride = conv_param->dilation_w_ * C8NUM;
+
+  NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_h_);
+  NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_w_);
+
+  for (int c = 0; c < oc8; c += 8) {
+    float16_t *dst_ptr = tmp + c * output_plane;
+    const float16_t *src_ptr = src + c * in_plane16 * kernel_plane;
+    memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float16_t));
+
+    for (int ih = 0; ih < conv_param->input_h_; ih++) {
+      for (int iw = 0; iw < conv_param->input_w_; iw++) {
+        int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
+        int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
+
+        int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
+        int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
+        int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
+        int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
+
+        const float16_t *src_in_ptr = src_ptr + ih * src_ih_stride + iw * src_iw_stride;
+        float16_t *dst_in_ptr = dst_ptr + oh * dst_oh_stride + ow * dst_ow_stride;
+
+        for (int kh = kh_start; kh < kh_end; kh++) {
+          const float16_t *src_kh_ptr = src_in_ptr + kh * src_kh_stride;
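+          // Deconv scatter (illustrative numbers): each input pixel adds a kernel patch
+          // into the output around (oh, ow); kh_start/kh_end above clip taps that would
+          // land outside [0, output_h_). E.g. oh = -3 with dilation_h_ = 2 gives
+          // kh_start = UP_DIV(3, 2) = 2, so the first written row is -3 + 2 * 2 = 1.
+          float16_t *dst_kh_ptr = dst_in_ptr + kh *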
dst_kh_stride;
+          DeConvPostAddC8WithStride(src_kh_ptr + kw_start * src_kw_stride, dst_kh_ptr + kw_start * dst_kw_stride,
+                                    src_kw_stride, dst_kw_stride, kw_end - kw_start);
+        }  // kh
+      }    // iw
+    }      // ih
+
+    /* add bias for current oh*ow*C8
+     * write to output data ptr in nhwc format */
+    float16x8_t bias_v = vld1q_f16(bias + c);
+    float16_t *pack_tmp_data = dst_ptr;
+    for (size_t i = 0; i < output_plane; i++) {
+      float16x8_t data_v = vld1q_f16(pack_tmp_data);
+      data_v = vaddq_f16(data_v, bias_v);
+      data_v = vminq_f16(data_v, max_v);
+      data_v = vmaxq_f16(data_v, min_v);
+      vst1q_f16(pack_tmp_data, data_v);
+      pack_tmp_data += C8NUM;
+    }
+  }  // oc8
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d3a657285007d641da0fb935e29fb36d1f134bd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_DECONV_FP16_H_
+#define NNACL_FP16_DECONV_FP16_H_
+
+#include <arm_neon.h>
+#include <string.h>
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp16/common_func_fp16.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
+                   const ConvParameter *conv_param);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_FP16_DECONV_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..00b9450d93fe6104d6a74412edd1b085ce0b892e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c
@@ -0,0 +1,519 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp16/deconv_winograd_fp16.h" +#include "nnacl_c/base/minimal_filtering_generator.h" + +void DeConvWgInputPackFp16(const float16_t *src_ptr, float16_t *dst_ptr, int channel, int stride) { + int ic4div = channel / C4NUM; + int ic4mod = channel % C4NUM; + const float16_t *src = src_ptr; + float16_t *dst = dst_ptr; + + for (int ic = 0; ic < ic4div; ic++) { + vst1_f16(dst, vld1_f16(src)); + dst += stride; + src += C4NUM; + } + + if (ic4mod != 0) { + int ic_res = 0; + for (; ic_res < ic4mod; ic_res++) { + dst[ic_res] = src[ic_res]; + } + for (; ic_res < C4NUM; ic_res++) { + dst[ic_res] = 0; + } + } + return; +} + +#ifdef ENABLE_ARM82_A32 +void DeconvWgMergeFp16A32Fun(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r0, %[src_ptr]\n" + "mov r1, %[dst_ptr]\n" + "mov r2, r1\n" + + "vld1.16 {d0}, [r0], %[src_step]\n" + "vld1.16 {d2}, [r1], %[dst_step]\n" + "vld1.16 {d4}, [r0], %[src_step]\n" + "vld1.16 {d6}, [r1], %[dst_step]\n" + "vadd.f16 d0, d0, d2\n" + "vld1.16 {d8}, [r0], %[src_step]\n" + "vadd.f16 d4, d4, d6\n" + "vst1.16 {d0}, [r2], %[dst_step]\n" + "vst1.16 {d4}, [r2], %[dst_step]\n" + + "vld1.16 {d10}, [r1], %[dst_step]\n" + "vld1.16 {d12}, [r0], %[src_step]\n" + "vadd.f16 d8, d8, d10\n" + "vld1.16 {d14}, [r1], %[dst_step]\n" + "vadd.f16 d12, d12, d14\n" + "vld1.16 {d0}, [r0], %[src_step]\n" + "vst1.16 {d8}, [r2], %[dst_step]\n" + "vst1.16 {d12}, [r2], %[dst_step]\n" + + "vld1.16 {d2}, [r1], %[dst_step]\n" + "vld1.16 {d4}, [r0], %[src_step]\n" + "vld1.16 {d6}, [r1], %[dst_step]\n" + "vadd.f16 d0, d0, d2\n" + "vadd.f16 d4, d4, d6\n" + "vst1.16 {d0}, [r2], %[dst_step]\n" + "vst1.16 {d4}, [r2], %[dst_step]\n" + + "vld1.16 {d8}, [r0], %[src_step]\n" + "vld1.16 {d10}, [r1], %[dst_step]\n" + "vld1.16 {d12}, [r0], %[src_step]\n" + "vld1.16 {d14}, [r1], %[dst_step]\n" + "vadd.f16 d8, d8, d10\n" + "vadd.f16 d12, d12, d14\n" + "vst1.16 {d8}, [r2], %[dst_step]\n" + "vst1.16 {d12}, [r2], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r0", "r1", "r2", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} +#endif + +void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float16_t *src_ptr = src; + float16_t *dst_ptr = dst; + size_t cuont8 = count / C8NUM * C8NUM; + int i = 0; + for (; i < cuont8; i += C8NUM) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float16_t); + size_t dst_step = dst_stride * sizeof(float16_t); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4h}, [x7], %[src_step]\n" + "ld1 {v1.4h}, [x8], %[dst_step]\n" + "ld1 {v2.4h}, [x7], %[src_step]\n" + "ld1 {v3.4h}, [x8], %[dst_step]\n" + "fadd v0.4h, v0.4h, v1.4h\n" + "ld1 {v4.4h}, [x7], %[src_step]\n" + "fadd v2.4h, v2.4h, v3.4h\n" + "st1 {v0.4h}, [x10], %[dst_step]\n" + "st1 {v2.4h}, [x10], %[dst_step]\n" + + "ld1 {v5.4h}, [x8], %[dst_step]\n" + "ld1 {v6.4h}, [x7], %[src_step]\n" + "fadd v4.4h, v4.4h, v5.4h\n" + "ld1 {v7.4h}, [x8], %[dst_step]\n" + "fadd v6.4h, v6.4h, v7.4h\n" + "ld1 {v0.4h}, [x7], %[src_step]\n" + "st1 {v4.4h}, [x10], %[dst_step]\n" + "st1 {v6.4h}, [x10], %[dst_step]\n" + + "ld1 {v1.4h}, [x8], %[dst_step]\n" + "ld1 {v2.4h}, [x7], %[src_step]\n" + "ld1 {v3.4h}, [x8], %[dst_step]\n" + "fadd v0.4h, v0.4h, v1.4h\n" + "fadd v2.4h, v2.4h, v3.4h\n" + "st1 {v0.4h}, [x10], %[dst_step]\n" + "st1 {v2.4h}, [x10], %[dst_step]\n" + + "ld1 
{v4.4h}, [x7], %[src_step]\n" + "ld1 {v5.4h}, [x8], %[dst_step]\n" + "ld1 {v6.4h}, [x7], %[src_step]\n" + "ld1 {v7.4h}, [x8], %[dst_step]\n" + "fadd v4.4h, v4.4h, v5.4h\n" + "fadd v6.4h, v6.4h, v7.4h\n" + "st1 {v4.4h}, [x10], %[dst_step]\n" + "st1 {v6.4h}, [x10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif defined(ENABLE_ARM82_A32) + size_t src_step = src_stride * sizeof(float16_t); + size_t dst_step = dst_stride * sizeof(float16_t); + DeconvWgMergeFp16A32Fun(src_ptr, dst_ptr, src_step, dst_step); +#else + for (int j = 0; j < C8NUM; j++) { + const float16_t *s = src_ptr + j * src_stride; + float16_t *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + + for (; i < count; i++) { + float16x4_t src_data = vld1_f16(src_ptr); + float16x4_t dst_data = vld1_f16(dst_ptr); + dst_data = vadd_f16(src_data, dst_data); + vst1_f16(dst_ptr, dst_data); + + src_ptr += src_stride; + dst_ptr += dst_stride; + } + return; +} + +void DeConvWgCalWgFp16(const float16_t *tile_in, float16_t *tile_out, const float16_t *weight_buf, float16_t *tmp_buf, + const float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transferred, + const float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, + const ConvParameter *conv_param, const DeConvParam *deconv_param) { + int winograd_plane = unit_size * unit_size; + if (!transferred[unit_size]) { + WinogradTransLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, + DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + transferred[unit_size] = true; + } + + for (int index = 0; index < winograd_plane; index++) { + float16_t *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *dst = tmp_buf + index * deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + const float16_t *weight = weight_buf + index * deconv_param->ic_up_ * deconv_param->oc_up_; + TiledC4MatmulFp16(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div_, + deconv_param->oc_div_); + } + + WinogradTransLeftFp16(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRightFp16(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + + // Add to dest + for (int uhi = 0; uhi < unit_size; uhi++) { + int h_index = uhi * conv_param->stride_h_ + h_start; + for (int uwi = 0; uwi < unit_size; uwi++) { + int w_index = uwi * conv_param->stride_w_ + w_start; + + float16_t *dst = tile_out + w_index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_ + + h_index * deconv_param->out_tile_w_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + float16_t *src = tmp_buf + (uwi + uhi * unit_size) * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgMergeFp16(src, dst, C4NUM, C4NUM, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + return; +} + +void DeConvWgCalCommFp16(const float16_t *tile_in, float16_t *tile_out, const float16_t *weight, float16_t *tmp_buf, + int 
h_start, int w_start, int h_size, int w_size, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { + int count = deconv_param->oc_div_ * w_size * h_size; + int in_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + int out_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + const float16_t *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; + TiledC4MatmulFp16(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div_, count); + + for (int uhi = 0; uhi < h_size; uhi++) { + for (int uwi = 0; uwi < w_size; uwi++) { + int w_index = (wi + uwi) * conv_param->stride_w_ + w_start; + int h_index = (hi + uhi) * conv_param->stride_h_ + h_start; + float16_t *dst = tile_out + h_index * out_stride * deconv_param->out_tile_w_ + w_index * out_stride; + float16_t *src = tmp_buf + (uwi + uhi * w_size) * out_stride; + DeConvWgMergeFp16(src, dst, C4NUM, C4NUM, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + } + } + return; +} + +int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { + int tmp_kernel_plane = unit->w_size_ * unit->h_size_; + int output_channel = conv_param->output_channel_; + int size = conv_param->input_channel_ * output_channel * tmp_kernel_plane; + float16_t *current_unit_weight = (float16_t *)malloc(size * sizeof(float16_t)); + if (current_unit_weight == NULL) { + return NNACL_NULL_PTR; + } + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + const float16_t *src_ic = nhwc_weight + deconv_param->kernel_plane_ * output_channel * ic; + float16_t *dst_ic = current_unit_weight + tmp_kernel_plane * output_channel * ic; + for (int uhi = 0; uhi < unit->h_size_; uhi++) { + for (int uwi = 0; uwi < unit->w_size_; uwi++) { + int src_h_offset = unit->h_start_ + uhi * conv_param->stride_h_; + int src_w_offset = unit->w_start_ + uwi * conv_param->stride_w_; + const float16_t *src_hw = src_ic + (src_h_offset * conv_param->kernel_w_ + src_w_offset) * output_channel; + float16_t *dst_hw = dst_ic + (uhi * unit->w_size_ + uwi) * output_channel; + memcpy(dst_hw, src_hw, output_channel * sizeof(float16_t)); + } + } + } + + if (unit->use_winograd_) { + /* Generate winograd */ + float matrix_g[64]; + float matrix_gt[64]; + float matrix_a[64]; + float matrix_at[64]; + float matrix_b[64]; + float matrix_bt[64]; + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, + DECONV_WINOGRAD_DEFAULT_UNIT, unit->h_size_); + if (ret != NNACL_OK) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR; + } + + /* winograd AT */ + unit->winograd_.AT_ = malloc(unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float16_t)); + if (unit->winograd_.AT_ == NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_NULL_PTR; + } + Float32ToFloat16(matrix_at, unit->winograd_.AT_, unit->winograd_.i_ * unit->winograd_.o_); + + /* winograd BT */ + unit->winograd_.BT_ = malloc(unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float16_t)); + if (unit->winograd_.BT_ == NULL) { + free(current_unit_weight); + free(unit->winograd_.AT_); + current_unit_weight = NULL; + unit->winograd_.AT_ = NULL; + return NNACL_NULL_PTR; + } + Float32ToFloat16(matrix_bt, unit->winograd_.BT_, unit->winograd_.o_ * 
unit->winograd_.o_); + + /* winograd Weight */ + size = conv_param->input_channel_ * output_channel * unit->winograd_.kh_ * unit->winograd_.kw_; + float16_t *winograd_unit_weight = (float16_t *)malloc(size * sizeof(float16_t)); + if (winograd_unit_weight == NULL) { + free(current_unit_weight); + free(unit->winograd_.AT_); + free(unit->winograd_.BT_); + current_unit_weight = NULL; + unit->winograd_.AT_ = NULL; + unit->winograd_.BT_ = NULL; + return NNACL_NULL_PTR; + } + + WinogradWeightTransformFp16(current_unit_weight, winograd_unit_weight, matrix_g, matrix_gt, C4NUM, + unit->winograd_.kh_, unit->h_size_, output_channel, conv_param->input_channel_, false); + + /* reset weight data & info */ + tmp_kernel_plane = unit->winograd_.kh_ * unit->winograd_.kw_; + free(current_unit_weight); + current_unit_weight = winograd_unit_weight; + winograd_unit_weight = NULL; + } + + /* trans mhwc -> hw1:k1-knc0-c4:k1-knc5-c8:hw2:k1-knc0-c4:k1 */ + float16_t *dst_weight = (float16_t *)unit->weight_; + size = deconv_param->ic_up_ * deconv_param->oc_up_ * tmp_kernel_plane; + memset(dst_weight, 0, size * sizeof(float16_t)); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int upi = 0; upi < tmp_kernel_plane; upi++) { + int src_index = ic * output_channel * tmp_kernel_plane + upi * output_channel + oc; + int dst_index = upi * deconv_param->oc_up_ * deconv_param->ic_up_ + oc4div * C4NUM * deconv_param->ic_up_ + + ic * C4NUM + oc4mod; + dst_weight[dst_index] = current_unit_weight[src_index]; + } + } + } + + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_OK; +} + +void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index, + int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) { + NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_); + /* pack tile input */ + int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + float16x4_t zero = vdup_n_f16(0.0f); + + for (int unit_index = 0; unit_index < calculate_count; unit_index++) { + int plane_index = start_index + unit_index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + + float16_t *dst_unit = tile_in + unit_index * C4NUM; + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + float16_t *dst = dst_unit + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * tile_in_unit_stride; + int w_index = w_start + wi; + int h_index = h_start + hi; + if (w_index >= conv_param->input_w_ || h_index >= conv_param->input_h_) { + for (int ic4_index = 0; ic4_index < deconv_param->ic_div_; ic4_index++) { + vst1_f16(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, zero); + } + continue; + } + + const float16_t *src = nhwc_input_ + (w_index + h_index * conv_param->input_w_) * conv_param->input_channel_; + DeConvWgInputPackFp16(src, dst, conv_param->input_channel_, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM); + } + } + } + + /* compute */ + bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + for (int i = 0; i < deconv_param->compute_size_; i++) { + DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; + if (unit->use_winograd_) { + float16_t *tmp_buf = 
(float16_t *)unit->tmp_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + /* winograd a buffer */ + if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT) { + return; + } + DeConvWgABuffer *tmp_a = &deconv_param->a_buffer_[unit->winograd_.kh_]; + float16_t *mid_a = (float16_t *)tmp_a->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *dst_a = (float16_t *)tmp_a->dest_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *tmp_b = (float16_t *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgCalWgFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, + transferred, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, + conv_param, deconv_param); + } else { + float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div_ * unit->w_size_ * + unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; + DeConvWgCalCommFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->h_start_, unit->w_start_, + unit->h_size_, unit->w_size_, conv_param, deconv_param); + } + } + return; +} + +void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index) { + NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_); + /* merge */ + int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + int src_stride = DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; + int dst_stride = conv_param->output_w_ * conv_param->output_h_ * C4NUM; + + for (int index = 0; index < calculate_count; ++index) { + const float16_t *src_start = tile_out + index * C4NUM; + + int plane_index = tile_index * DECONV_WINOGRAD_DEFAULT_TILE + index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_w_ - conv_param->pad_l_; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_h_ - conv_param->pad_u_; + float16_t *dst_start = nc4hw4_output + h_start * conv_param->output_w_ * C4NUM + w_start * C4NUM; + + int merge_w_start = MSMAX(-w_start, 0); + int merge_h_start = MSMAX(-h_start, 0); + int merge_h_end = MSMIN(deconv_param->out_tile_h_, conv_param->output_h_ - h_start); + int merge_w_end = MSMIN(deconv_param->out_tile_w_, conv_param->output_w_ - w_start); + + for (int hi = merge_h_start; hi < merge_h_end; hi++) { + for (int wi = merge_w_start; wi < merge_w_end; wi++) { + const float16_t *src = src_start + (hi * deconv_param->out_tile_w_ + wi) * src_unit_stride; + float16_t *dst = dst_start + (hi * conv_param->output_w_ + wi) * C4NUM; + DeConvWgMergeFp16(src, dst, src_stride, dst_stride, deconv_param->oc_div_); + } + } + } + return; +} + +#ifndef ENABLE_ARM +void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + const int unitStep = C4NUM * length; + for (int y = 0; y < h; ++y) { + float16_t *dstY = M + y * w * unitStep; + for (int x = 0; x < w; ++x) { + float16_t *dstX = dstY + x * unitStep; + const float16_t *srcX = S + x * unitStep; 
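+      // Computes M(y, x) = sum_i B(i, y) * S(i, x), i.e. M = B^T * S, where each
+      // "element" is a length-unitStep run of C4 channel vectors; zero entries of B
+      // are skipped below since the Winograd transform matrices are sparse.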
+ memset(dstX, 0, unitStep * sizeof(float16_t)); + for (int i = 0; i < k; ++i) { + float16_t b = B[i * h + y]; + const float16_t *srcY = srcX + i * w * unitStep; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcY[j] * b; + } + } + } + } +} + +void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + const int unitStep = C4NUM * length; + for (int y = 0; y < h; ++y) { + float16_t *dstY = M + y * w * unitStep; + const float16_t *srcY = S + y * k * unitStep; + + for (int x = 0; x < w; ++x) { + float16_t *dstX = dstY + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float16_t)); + for (int i = 0; i < k; ++i) { + const float16_t *srcX = srcY + i * unitStep; + float16_t b = B[i * h + x]; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcX[j] * b; + } + } + } + } +} + +void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, + size_t oc4) { + int dx, sz, dz; + int src_depth_step = C4NUM * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float16_t *dst_z = dst + dz * cal_num; + const float16_t *weight_dz = weight + dz * ic4 * 16; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { + float16_t *dst_x = dst_z + dx * C4NUM; + dst_x[0] = 0.0f; + dst_x[1] = 0.0f; + dst_x[2] = 0.0f; + dst_x[3] = 0.0f; + const float16_t *src_dx = src + C4NUM * dx; + for (sz = 0; sz < ic4; ++sz) { + const float16_t *src_z = src_dx + sz * src_depth_step; + const float16_t *weight_z = weight_dz + sz * 16; + for (int i = 0; i < C4NUM; ++i) { + for (int j = 0; j < C4NUM; ++j) { + dst_x[j] += src_z[i] * weight_z[C4NUM * i + j]; + } + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..dbb5bcd1d38d07b9015c7de2d7c74c362d78f3b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP16_DECONV_WINOGRAD_FP16_H_
+#define NNACL_FP16_DECONV_WINOGRAD_FP16_H_
+
+#include "nnacl_c/fp16/winograd_transform_fp16.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param,
+                         const DeConvParam *deconv_param);
+
+void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index,
+                  int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
+
+void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param,
+                      const DeConvParam *deconv_param, int calculate_count, int tile_index);
+
+void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4,
+                       size_t oc4);
+
+void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
+                           size_t length);
+
+void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
+                            size_t length);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_DECONV_WINOGRAD_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..de3d1deae935cec4ce2284835c55a75b7804921a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16/dynamic_quant_fp16.h"
+
+void CalculateMinMaxFp16(const float16_t *data, int count, float16_t *real_min, float16_t *real_max) {
+#ifndef ENABLE_ARM64
+  for (int i = 0; i < count; ++i) {
+    if (data[i] < *real_min) {
+      *real_min = data[i];
+    }
+    if (data[i] > *real_max) {
+      *real_max = data[i];
+    }
+  }
+#else
+  // volatile: keep the rounded count from being optimized away by the compiler.
+  volatile int count_8 = DOWN_ROUND(count, C8NUM);
+  CalculateMinMaxCount8Fp16(data, count_8, real_min, real_max);
+  for (int i = count_8; i < count; ++i) {
+    if (data[i] < *real_min) {
+      *real_min = data[i];
+    }
+    if (data[i] > *real_max) {
+      *real_max = data[i];
+    }
+  }
+#endif
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ab6cf5ee4188ecd668d0114b15d23f86a789239
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_QUANT_FP16_H_ +#define NNACL_INT8_DYNAMIC_QUANT_FP16_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CalculateMinMaxFp16(const float16_t *data, int count, float16_t *real_min, float16_t *real_max); + +#ifdef ENABLE_ARM64 +void CalculateMinMaxCount8Fp16(const float16_t *data, int count_8, float16_t *real_min, float16_t *real_max); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_QUANT_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..adeb0a7eb7194f39c0d4bfe8ea63988be758b6af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp16/exp_fp16.h" +#include +#include +#include "nnacl_c/errorcode.h" + +#if defined(ENABLE_NEON) +static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) { + static float16x8_t maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static float16x8_t minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = vmaxq_f16(minv, vminq_f16(input, maxv)); + vst1q_f16(dst, VexpFp16(input)); +} +#endif + +void ExpFp16(const float16_t *src, float16_t *dst, int num) { + int i = 0; +#ifdef ENABLE_NEON + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(vld1q_f16(src + i), dst + i); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i], dst + i); + } +} + +int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float16_t *src = (float16_t *)src_data; + float16_t *dst = (float16_t *)dst_data; + int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_); + int start = stride * task_id; + int end = MSMIN(exp->element_num_, start + stride); + int num = end - start; + + if (param->scale_ == 1) { + ExpFp16(src + start, dst + start, num); + } else { + int i = 0; +#ifdef ENABLE_ARM64 + MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->in_scale_); + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(MS_MULQ_F16(MS_LDQ_F16(src + i), scale), dst + i); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i] * exp->in_scale_, dst + i); + } + } + if (exp->out_scale_ != 1) { + int i = 0; +#ifdef ENABLE_ARM64 + MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->out_scale_); + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(MS_LDQ_F16(src + i), dst + i); + MS_STQ_F16(dst + i, MS_MULQ_F16(MS_LDQ_F16(dst + i), scale)); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i], dst + i); + dst[i] *= exp->out_scale_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d5b308253539cd4c1f6a1cc0b57481087229c6d1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_EXP_FP16_H_ +#define NNACL_FP16_EXP_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/exp.h" +#include "nnacl_c/exp_parameter.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ExpFp16(const float16_t *src, float16_t *dst, int num); +int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id); + +#ifdef ENABLE_NEON +static inline float16x8_t VexpFp16(float16x8_t input) { + float32x4_t input_low = MS_CVT_F32_F16(vget_low_f16(input)); + float32x4_t input_high = MS_CVT_F32_F16(vget_high_f16(input)); + return vcombine_f16(MS_CVT_F16_F32(VexpFp32(input_low)), MS_CVT_F16_F32(VexpFp32(input_high))); +} +#endif + +static inline void single_exp_fp16(float16_t src, float16_t *dst) { + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; + int integer; + if (src > 0) { + src = MSMIN(88.72283935546875f, src); + integer = (float)src * 1.44269504088896341f + 0.5f; + } else { + src = MSMAX(-87.3365478515625f, src); + integer = (float)src * 1.44269504088896341f - 0.5f; + } + const int shift = 23; + const int bias = 126; + const float factor = 2; + float decimal = (float)src - integer * param[0]; + int int_exp = (integer + bias) << shift; + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + float *tmp = (float *)(&int_exp); + *dst = (float16_t)(*(tmp)*decimal_exp * factor); +} + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_EXP_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..a4b32348b180d177c4072e0e1bbf7f7521539045 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c @@ -0,0 +1,24 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/fill_fp16.h" + +inline int FillFp16(float16_t *output, int size, float16_t data) { + for (int i = 0; i < size; ++i) { + output[i] = data; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..c177e9a910b3a7985061df182513cfce4236a346 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_FILL_FP16_H_
+#define NNACL_FP16_FILL_FP16_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fill_parameter.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int FillFp16(float16_t *output, int size, float16_t data);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_FILL_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..ff2f9e44866207df3e3c2c24485a7aa62d1c1144
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/gru_fp16.h"
+#include <string.h>
+#include "nnacl_c/fp16/lstm_fp16.h"
+#include "nnacl_c/fp16/activation_fp16.h"
+#include "nnacl_c/fp16/arithmetic_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
+
+void GruStepUnitFp16(float16_t *output, float16_t *update_gate, float16_t *reset_gate, float16_t *hidden_buffer,
+                     const float16_t *state_weight, const float16_t *state_bias, float16_t *hidden_state,
+                     float16_t *buffer[4], const GruParameter *gru_param) {
+  float16_t *packed_state = buffer[2];
+  float16_t *state_gate = buffer[3];
+  bool is_vec = gru_param->batch_ == 1;
+
+  const float16_t *state_update_weight = state_weight;
+  const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
+  const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
+  float16_t *state_update_gate = state_gate;
+  float16_t *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_;
+  float16_t *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2;
+  const float16_t *state_update_bias = state_bias;
+  const float16_t *state_reset_bias = state_bias + gru_param->hidden_size_;
+  const float16_t *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2;
+
+  // state * weight
+  if (is_vec) {
+    LstmMatMulFp16(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+                   gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
+                   gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  } else {
+    RowMajor2Col16MajorFp16(hidden_state, packed_state, gru_param->batch_,
gru_param->hidden_size_, false); + LstmMatMulFp16(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + ElementAddFp16(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); + + // update reset_gate + SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + + // update update_gate + SigmoidFp16(update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + + ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + TanhFp16(hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + ElementMulFp16(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + float16_t one = 1.0f; + ElementOptSubFp16(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, true); + + ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t)); +} + +void GruUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_g, + const float16_t *weight_r, const float16_t *input_bias, const float16_t *state_bias, + float16_t *hidden_state, float16_t *buffer[4], const GruParameter *gru_param, + bool is_backward) { + float16_t *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float16_t *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + gru_param->input_col_align_ * i; + float16_t *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float16_t *update_gate = gate; + float16_t *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float16_t *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? 
gru_param->seq_len_ - t - 1 : t; + float16_t *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnitFp16(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, + buffer, gru_param); + } +} + +void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], + int check_seq_len, const GruParameter *gru_param) { + // forward + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_, false); + GruUnidirectionalFp16(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, + gru_param, false); + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float16_t *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_param->bidirectional_) { + const float16_t *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float16_t *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float16_t *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; + float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; + GruUnidirectionalFp16(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + + // zero out extra bw outputs + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float16_t *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..ea8a7b71fbb044c3da8c6cb37e14316fac665aae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_FP16_GRU_H_
+#define NNACL_FP16_GRU_H_
+#include "nnacl_c/gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
+             const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4],
+             int check_seq_len, const GruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRU_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..a174e46eaa89684556ba7a0ece629a0f52abfb0e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c
@@ -0,0 +1,217 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/instance_norm_fp16.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h"
+
+int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data,
+                     const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {
+  NNACL_CHECK_NULL_RETURN_ERR(src_data);
+  NNACL_CHECK_NULL_RETURN_ERR(dst_data);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
+  int channel = param->channel_;
+  int hw_plane = param->inner_size_;
+  NNACL_CHECK_ZERO_RETURN_ERR(hw_plane);
+  int channel_step = UP_DIV(channel, param->op_parameter_.thread_num_);
+  int channel_begin = task_id * channel_step;
+  int channel_end = MSMIN(channel_begin + channel_step, channel);
+
+  for (int b = 0; b < param->batch_; b++) {
+    const float16_t *src_b = src_data + b * channel * hw_plane;
+    float16_t *dst_b = dst_data + b * channel * hw_plane;
+    for (int c = channel_begin; c < channel_end; c++) {
+      const float16_t *src = src_b + c * hw_plane;
+      float16_t *dst = dst_b + c * hw_plane;
+      float mean = 0.0f;
+      float square_mean = 0.0f;
+
+      int index = 0;
+      for (; index <= hw_plane - C8NUM; index += C8NUM) {
+        float16x8_t srcv = vld1q_f16(src + index);
+        float16x8_t squarev = vmulq_f16(srcv, srcv);
+
+        float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv));
+        float32x4_t sum_f32 = vcvt_f32_f16(sum2);
+        mean += MS_ADDVQ_F32(sum_f32);
+
+        float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev));
+        float32x4_t square_f32 = vcvt_f32_f16(square2);
+        square_mean += MS_ADDVQ_F32(square_f32);
+      }
+      for (; index < hw_plane; index++) {
+        mean += src[index];
+        square_mean += src[index] * src[index];
+      }
+
+      mean /= (float)hw_plane;
+      square_mean /= (float)hw_plane;
+      const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_);
+
+      index = 0;
+      float16x8_t meanv = vdupq_n_f16(mean);
+      float16x8_t denov = vdupq_n_f16(deno);
+      for (; index <= hw_plane - C8NUM; index += C8NUM) {
+        float16x8_t srcv = vld1q_f16(src + index);
+        float16x8_t
outv = vsubq_f16(srcv, meanv); + outv = vmulq_f16(outv, denov); + + float16x8_t gammav = vdupq_n_f16(gamma_data[c]); + float16x8_t betav = vdupq_n_f16(beta_data[c]); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index, outv); + } + for (; index < hw_plane; index++) { + dst[index] = (src[index] - mean) * deno; + dst[index] = dst[index] * gamma_data[c] + beta_data[c]; + } + } + } + return NNACL_OK; +} + +int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + NNACL_CHECK_ZERO_RETURN_ERR(hw_plane); + int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + int c8_down = channel_end / C8NUM * C8NUM; + int c_res = channel_end - c8_down; + float32x4_t hw_plane_4 = vdupq_n_f32(hw_plane); + for (int b = 0; b < param->batch_; b++) { + const float16_t *src_b = src_data + b * channel * hw_plane; + float16_t *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float16_t *src = src_b + c * hw_plane; + const float16_t *src1 = src_b + (c + C8NUM) * hw_plane; + float16_t *dst = dst_b + c; + float32x4_t mean1 = vdupq_n_f32(0.0f); + float32x4_t mean2 = vdupq_n_f32(0.0f); + float32x4_t mean3 = vdupq_n_f32(0.0f); + float32x4_t mean4 = vdupq_n_f32(0.0f); + float32x4_t square_mean1 = vdupq_n_f32(0.0f); + float32x4_t square_mean2 = vdupq_n_f32(0.0f); + float32x4_t square_mean3 = vdupq_n_f32(0.0f); + float32x4_t square_mean4 = vdupq_n_f32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM); + + float32x4_t srcv01 = vcvt_f32_f16(vget_low_f16(srcv)); + float32x4_t srcv02 = vcvt_f32_f16(vget_high_f16(srcv)); + float32x4_t srcv11 = vcvt_f32_f16(vget_low_f16(srcv1)); + float32x4_t srcv12 = vcvt_f32_f16(vget_high_f16(srcv1)); + mean1 = vaddq_f32(mean1, srcv01); + mean2 = vaddq_f32(mean2, srcv02); + mean3 = vaddq_f32(mean3, srcv11); + mean4 = vaddq_f32(mean4, srcv12); + square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv01, srcv01)); + square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv02, srcv02)); + square_mean3 = vaddq_f32(square_mean3, vmulq_f32(srcv11, srcv11)); + square_mean4 = vaddq_f32(square_mean4, vmulq_f32(srcv12, srcv12)); + } + float16x8_t mean = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4))); + float16x8_t mean_1 = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean3, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean4, hw_plane_4))); + float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4))); + float16x8_t square_mean_1 = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean3, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean4, hw_plane_4))); + float16x8_t deno = vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_)); + float16x8_t deno1 = vaddq_f16(vsubq_f16(square_mean_1, vmulq_f16(mean_1, mean_1)), 
vdupq_n_f16(param->epsilon_)); + deno = 1 / MS_SQRTFX8_F16(deno); + deno1 = 1 / MS_SQRTFX8_F16(deno1); + + float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c] + float16x8_t gammav1 = vmulq_f16(vld1q_f16(gamma_data + c + C8NUM), deno1); // deno * gamma_data[c] + float16x8_t betav = vld1q_f16(beta_data + c); + float16x8_t betav1 = vld1q_f16(beta_data + c + C8NUM); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM); + float16x8_t outv = vsubq_f16(srcv, mean); + float16x8_t outv1 = vsubq_f16(srcv1, mean_1); + outv = vmulq_f16(outv, gammav); + outv1 = vmulq_f16(outv1, gammav1); + outv = vaddq_f16(outv, betav); + outv1 = vaddq_f16(outv1, betav1); + vst1q_f16(dst + index * channel, outv); + vst1q_f16(dst + index * channel + C8NUM, outv1); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float16_t *src = src_b + c * hw_plane; + float16_t *dst = dst_b + c; + float32x4_t mean1 = vdupq_n_f32(0.0f); + float32x4_t mean2 = vdupq_n_f32(0.0f); + float32x4_t square_mean1 = vdupq_n_f32(0.0f); + float32x4_t square_mean2 = vdupq_n_f32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float32x4_t srcv1 = vcvt_f32_f16(vget_low_f16(srcv)); + float32x4_t srcv2 = vcvt_f32_f16(vget_high_f16(srcv)); + mean1 = vaddq_f32(mean1, srcv1); + mean2 = vaddq_f32(mean2, srcv2); + square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv1, srcv1)); + square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv2, srcv2)); + } + float16x8_t mean = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4))); + float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4))); + float16x8_t deno = + vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_)); // question + deno = 1 / MS_SQRTFX8_F16(deno); // question + + float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c] + float16x8_t betav = vld1q_f16(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t outv = vsubq_f16(srcv, mean); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index * channel, outv); + } + } + for (; c < channel_end; ++c) { + const float16_t *src = src_b + c8_down * hw_plane + c; + float16_t *dst = dst_b + c; + float mean = 0.0f; + float square_mean = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float16_t tmp = src[index * c_res]; + mean += tmp; + square_mean += tmp * tmp; + } + mean /= (float)hw_plane; + square_mean /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(square_mean - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..fdbbd065304b7313de1f2353619a183d1024c845 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_INSTANCE_NORM_FP16_H_ +#define NNACL_FP16_INSTANCE_NORM_FP16_H_ + +#include "nnacl_c/instance_norm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_INSTANCE_NORM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d6e2b545e4153210f014eea20d6b9d7a2c4c69d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c @@ -0,0 +1,110 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp16/layer_norm_fp16.h" +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +int LayerNormMeanAndSquareFp16(const float16_t *src, int num, float16_t *mean, float16_t *variance) { + if (num <= 0) { + return NNACL_ERR; + } + int index = 0; + float sum = 0.0f; + float square_mean = 0.0f; + for (; index <= num - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + for (int i = 0; i < C8NUM; ++i) { + square_mean += srcv[i] * srcv[i]; + } + float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); + float32x4_t sum_f32 = vcvt_f32_f16(sum2); + sum += MS_ADDVQ_F32(sum_f32); + } + for (; index < num; index++) { + sum += src[index]; + square_mean += src[index] * src[index]; + } + *mean = (float16_t)(sum / num); + square_mean = square_mean / num; + *variance = square_mean - (*mean) * (*mean); + return NNACL_OK; +} + +void LayerNormGammaAndBetaFp16(float16_t *dst, const float16_t *src, const float16_t *gamma_data, + const float16_t *beta_data, int num, const float16_t mean, const float16_t deno) { + int index = 0; + float16x8_t meanv = vdupq_n_f16(mean); + float16x8_t denov = vdupq_n_f16(deno); + for (; index <= num - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t outv = vsubq_f16(srcv, meanv); + outv = vmulq_f16(outv, denov); + float16x8_t gammav = vld1q_f16(gamma_data + index); + float16x8_t betav = vld1q_f16(beta_data + index); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index, outv); + } + for (; index < num; index++) { + dst[index] = (src[index] - mean) * (deno); + dst[index] = dst[index] * gamma_data[index] + beta_data[index]; + } +} + +int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data, + float16_t *dst_data, float16_t *out_mean, float16_t *out_variance, const LayerNormComputeParam *param, + int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const float16_t *src_norm = src_data + i * param->norm_inner_size_; + float16_t *dst_norm = dst_data + i * param->norm_inner_size_; + float16_t cur_mean = 0.0f; + float16_t cur_variance = 0.0f; + int ret = LayerNormMeanAndSquareFp16(src_norm, param->norm_inner_size_, &cur_mean, &cur_variance); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + if (out_mean != NULL) { + out_mean[i] = cur_mean; + } + if (out_variance != NULL) { + out_variance[i] = cur_variance; + } + const float16_t deno = 1 / sqrtf(cur_variance + param->epsilon_); + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const float16_t *src_param = src_norm + x * param->params_inner_size_; + float16_t *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBetaFp16(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, + deno); + } + } else { + int x = i / param->params_outer_size_; + const float16_t *gamma = gamma_data + x * param->norm_inner_size_; + const float16_t *beta = beta_data + x * 
param->norm_inner_size_; + LayerNormGammaAndBetaFp16(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5a1992ba50706a276f5e6db0424daef77a6868 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_LAYER_NORM_FP16_H_ +#define NNACL_FP16_LAYER_NORM_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data, + float16_t *dst_data, float16_t *out_mean, float16_t *out_variance, const LayerNormComputeParam *param, + int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_LAYER_NORM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d00a48073368883cd1eba39ef913c7b2f01bc89b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp16/log_softmax_fp16.h"
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/fp16/softmax_fp16.h"
+#include "nnacl_c/fp16/exp_fp16.h"
+
+void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel) {
+  SoftmaxNormFp16(src, dst, batch, channel);
+  ExpFp16(dst, exp_data, batch * channel);
+  int cur_batch_offset = 0;
+  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
+    float16_t sum = 0;
+    int j = 0;
+#ifdef ENABLE_NEON
+    float16x8_t sum8 = vdupq_n_f16(0);
+    int count = (channel / C8NUM) * C8NUM;
+    for (; j < count; j += C8NUM) {
+      sum8 = vaddq_f16(sum8, vld1q_f16(exp_data + cur_batch_offset + j));
+    }
+    sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7];
+#endif
+    for (; j < channel; j++) {
+      sum += exp_data[cur_batch_offset + j];
+    }
+    for (int k = 0; k < channel; k++) {
+      dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - log(sum);
+    }
+  }
+}
+
+// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis))
+void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim,
+                    int axis) {
+  int inner_size = 1;
+  int outter_size = 1;
+
+  for (int i = 0; i < axis; i++) {
+    outter_size *= input_shape[i];
+  }
+  for (int i = axis + 1; i < n_dim; i++) {
+    inner_size *= input_shape[i];
+  }
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * input_shape[axis] * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      float16_t max_data = input_ptr[inner_offset];
+      sum_data[k + sum_outter_offset] = 0;
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset];
+      }
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        output_ptr[axis_offset] = input_ptr[axis_offset] - max_data;
+        sum_data[k + sum_outter_offset] += exp(output_ptr[axis_offset]);
+      }
+    }
+  }
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * input_shape[axis] * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int j = 0; j < input_shape[axis]; j++) {
+      int axis_offset = outter_offset + j * inner_size;
+      for (int k = 0; k < inner_size; k++) {
+        int inner_offset = axis_offset + k;
+        output_ptr[inner_offset] = output_ptr[inner_offset] - log(sum_data[k + sum_outter_offset]);
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..f4cf14d03960cc163a42e54dafe14a95c59d4d59
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_LOG_SOFTMAX_FP16_H_
+#define NNACL_FP16_LOG_SOFTMAX_FP16_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/softmax_parameter.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#ifdef __cplusplus
+extern "C" {
+#endif
+void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel);
+void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim,
+                    int axis);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_LOG_SOFTMAX_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..d812b6e13a998f39a21fc4a1a4595700da6711f1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c
@@ -0,0 +1,367 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16/lstm_fp16.h"
+#include <string.h>
+#include <float.h>
+#include "nnacl_c/fp16/activation_fp16.h"
+#include "nnacl_c/fp16/arithmetic_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h"
+
+void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align,
+                              const int32_t *order) {
+  for (int i = 0; i < batch; i++) {
+    const float *src_batch = src + i * col * deep;
+    float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep;
+#ifdef ENABLE_ARM64
+    RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, true);
+#else
+    RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true);
+#endif
+  }
+}
+
+void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align,
+                        const int32_t *order) {
+  for (int i = 0; i < batch; i++) {
+    const float16_t *src_batch = src + i * col * deep;
+    float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep;
+#ifdef ENABLE_ARM64
+    RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, false);
+#else
+    RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false);
+#endif
+  }
+}
+
+void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
+                            const int32_t *order) {
+  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
+  for (int i = 0; i < unidirectional_batch; i++) {
+    const float *src_batch = src + i * col;
+    float16_t *dst_batch = dst + (order == NULL ?
i : order[i]) * col_align; + Float32ToFloat16(src_batch, dst_batch, col); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + (order == NULL ? i : order[i]) * col_align; + Float32ToFloat16(backward_src_batch, backward_dst_batch, col); + } + } +} + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *src_batch = src + i * col; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float16_t)); + } + if (is_bidirectional) { + const float16_t *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + (order == NULL ? i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t)); + } + } +} + +// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size) { + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + float16_t res = 0; + const float16_t *input_col = input + r * inner_size; + const float16_t *weight_col = weight + c * inner_size; + int index = 0; + float16x8_t out = vdupq_n_f16(0.0f); + for (; index <= inner_size - 8; index += 8) { + float16x8_t in_0 = vld1q_f16(input_col + index); + float16x8_t in_1 = vld1q_f16(weight_col + index); + out = vfmaq_f16(out, in_1, in_0); + } + float16x4_t add2 = vadd_f16(vget_low_f16(out), vget_high_f16(out)); + float16x4_t add4 = vpadd_f16(add2, add2); + float16x4_t add8 = vpadd_f16(add4, add4); + res += vget_lane_f16(add8, 0); + for (; index < inner_size; index++) { + res += input_col[index] * weight_col[index]; + } + output[r * cols + c] += res; + } + } +} + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; + for (; index <= element_size - 8; index += 8) { + float16x8_t in_0 = vld1q_f16(input0 + index); + float16x8_t in_1 = vld1q_f16(input1 + index); + float16x8_t out = vld1q_f16(output + index); + out = vfmaq_f16(out, in_1, in_0); + vst1q_f16(output + index, out); + } + for (; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size) { + int index = 0; + for (; index <= element_size - 8; index += 8) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vld1q_f16(output + index); + vout = MS_FMAQ_N_F16(vout, vin0, input1); + vst1q_f16(output + index, vout); + } + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + +void UpdateStateFp16(float16_t *cell_state, const float16_t *forget_gate, const float16_t *input_gate, + const float16_t *cell_gate, float16_t *state_buffer, int batch, int hidden_size, + 
float16_t zoneout) { + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // zoneout * old_cell_state + (void)memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float16_t)); + ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); + } + + ElementMulFp16(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAccFp16(input_gate, cell_gate, cell_state, batch * hidden_size); + + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // (1 - zoneout) * new_cell_state + ElementOptMulAccFp16(cell_state, 1 - zoneout, state_buffer, batch * hidden_size); + } +} + +void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_t *cell_state, float16_t *output_gate, + const float16_t *weight_project, const float16_t *project_bias, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param) { + int batch = lstm_param->batch_; + int hidden_size = lstm_param->hidden_size_; + int output_size = lstm_param->output_size_; + float16_t *state_buffer = buffer[C5NUM]; + float16_t *hidden_buffer = weight_project ? buffer[C3NUM] : hidden_state; + float16_t zoneout = lstm_param->zoneout_hidden_; + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float16_t)); + ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * output_size, false); + } + + TanhFp16(cell_state, hidden_buffer, batch * hidden_size); + ElementMulFp16(hidden_buffer, output_gate, hidden_buffer, batch * hidden_size); + + if (weight_project) { + float16_t *left_matrix = hidden_buffer; +#ifdef ENABLE_ARM64 + if (batch >= C4NUM) { + left_matrix = buffer[C6NUM]; + RowMajor2ColLadder12MajorFp16(hidden_buffer, left_matrix, batch, hidden_size); + } +#else + if (batch != 1) { + left_matrix = buffer[C6NUM]; + RowMajor2Col16MajorFp16(hidden_buffer, left_matrix, batch, hidden_size, false); + } +#endif + LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, output_size, + batch == 1); + } + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * output_size); + } + (void)memcpy(output, hidden_state, batch * output_size * sizeof(float16_t)); +} + +#ifdef ENABLE_ARM64 +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + MatmulFp16OptV2(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); +} +#else +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + if (is_vec) { + (void)memcpy(c, bias, col * sizeof(float16_t)); + MatMulAccFp16(c, a, b, row, col, deep); + } else { + MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} +#endif + +void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias, + int row, int deep, int col, int col_align, bool is_vec) { + for (int i = 0; i < 4; i++) { + const float16_t *weight_i = weight + deep * col_align * i; + const float16_t *bias_i = bias + col_align * i; + float16_t *gate = gate_buffer + row * col * i; + LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); + } +} + +void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forget_gate, float16_t *cell_gate, + float16_t *output_gate, const float16_t *state_weight, const float16_t 
*state_bias, + const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, + float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param) { + float16_t *packed_state = buffer[C2NUM]; + float16_t *state_gate = buffer[C3NUM]; + float16_t *cell_buffer = buffer[C4NUM]; + float16_t *hidden_buffer = buffer[C5NUM]; + bool is_vec = lstm_param->batch_ == 1; +#ifdef ENABLE_ARM64 + if (lstm_param->batch_ <= C3NUM) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } else { + RowMajor2ColLadder12MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } +#else + if (is_vec) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } else { + RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_, false); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } +#endif + ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + + // update input_gate + SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + + // update forget_gate + SigmoidFp16(forget_gate, forget_gate, lstm_param->batch_ * lstm_param->hidden_size_); + + // update cell_gate + TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_); + // update cell state + UpdateStateFp16(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, + lstm_param->hidden_size_, lstm_param->zoneout_cell_); + + // update output_gate + SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_); + // update output + UpdateOutputFp16(hidden_state, output, cell_state, output_gate, weight_project, project_bias, buffer, lstm_param); + + if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { + (void)memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + } + + if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { + (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float16_t)); + } +} + +void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, + const LstmParameter *lstm_param) { + int row_input = lstm_param->seq_len_ * lstm_param->batch_; + for (int i = 0; i < C4NUM; i++) { + const float16_t *weight_loop = weight_i + 
lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; +#ifdef ENABLE_ARM64 + MatmulFp16OptV2(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, + lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); +#else + MatMulFp16(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, + lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); +#endif + } +} + +void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, + const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, + const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, + float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, + bool is_backward) { + float16_t *gate = buffer[1]; + LstmGateCompute(gate, packed_input, weight_i, input_bias, lstm_param); + + float16_t *input_gate = gate; + float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float16_t *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float16_t *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t; + float16_t *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnitFp16(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, + weight_project, project_bias, hidden_state, cell_state, buffer, lstm_param); + } +} + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *input_bias, const float16_t *state_bias, const float16_t *weight_project, + const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param) { + // forward +#ifdef ENABLE_ARM64 + const float16_t *packed_input = input; + if (lstm_param->batch_ * lstm_param->seq_len_ >= C4NUM) { + float16_t *temp_input = buffer[0]; + RowMajor2ColLadder12MajorFp16(input, temp_input, lstm_param->seq_len_ * lstm_param->batch_, + lstm_param->input_size_); + packed_input = temp_input; + } +#else + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_, + false); +#endif + LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, weight_project, project_bias, + hidden_state, cell_state, buffer, lstm_param, false); + + // backward + if (lstm_param->bidirectional_) { + const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * 
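/* skip the four forward-direction gate matrices */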
lstm_param->output_size_; + const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; + const float16_t *backward_weight_project = + weight_project ? weight_project + lstm_param->hidden_size_ * lstm_param->proj_col_align_ : NULL; + float16_t *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; + float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; + + LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_weight_project, project_bias, backward_hidden_state, + backward_cell_state, buffer, lstm_param, true); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..675643f9aa0443131d1dc768c2ce58387d7ac0ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_LSTM_FP16_H_ +#define NNACL_FP16_LSTM_FP16_H_ + +#include "nnacl_c/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order); + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, + const int32_t *order); + +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec); + +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size); + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size); + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *input_bias, const float16_t *state_bias, const float16_t *weight_project, + const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_LSTM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..9d239cde6481415246f1b29b891d8d68ea9e39fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c @@ -0,0 +1,1204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/matmul_fp16.h" + +static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + const float16_t *src = (const float16_t *)src_ptr; + int ci = 0; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float16_t *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; +#ifdef ENABLE_ARM64 + size_t strid_row = row * 2; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v0.8h}, [x10], x12\n" + "ld1 {v1.8h}, [x10], x12\n" + "ld1 {v2.8h}, [x10], x12\n" + "ld1 {v3.8h}, [x10], x12\n" + "ld1 {v4.8h}, [x10], x12\n" + "ld1 {v5.8h}, [x10], x12\n" + "ld1 {v6.8h}, [x10], x12\n" + "ld1 {v7.8h}, [x10], x12\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [dst_ptr1] "r"(dst_ptr1), [src_ptr1] "r"(src_ptr1), [strid_row] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else + for (int tr = 0; tr < C8NUM; ++tr) { + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr]; + } + } +#endif + } + for (; ri < row; ++ri) { + const float16_t *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri]; + } + } + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r]; + } + } +} + +static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + int ci = 0; + const float *src = (const float *)src_ptr; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; +#ifdef ENABLE_ARM64 + size_t strid_row = row * 4; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v8.4s, v9.4s}, [x10], x12\n" + "ld1 {v10.4s, v11.4s}, [x10], x12\n" + "ld1 {v12.4s, v13.4s}, [x10], x12\n" + "ld1 {v14.4s, v15.4s}, [x10], x12\n" + "ld1 {v16.4s, v17.4s}, [x10], x12\n" + "ld1 
{v18.4s, v19.4s}, [x10], x12\n" + "ld1 {v20.4s, v21.4s}, [x10], x12\n" + "ld1 {v22.4s, v23.4s}, [x10], x12\n" + + "fcvtn v0.4h, v8.4s\n" + "fcvtn2 v0.8h, v9.4s\n" + "fcvtn v1.4h, v10.4s\n" + "fcvtn2 v1.8h, v11.4s\n" + "fcvtn v2.4h, v12.4s\n" + "fcvtn2 v2.8h, v13.4s\n" + "fcvtn v3.4h, v14.4s\n" + "fcvtn2 v3.8h, v15.4s\n" + "fcvtn v4.4h, v16.4s\n" + "fcvtn2 v4.8h, v17.4s\n" + "fcvtn v5.4h, v18.4s\n" + "fcvtn2 v5.8h, v19.4s\n" + "fcvtn v6.4h, v20.4s\n" + "fcvtn2 v6.8h, v21.4s\n" + "fcvtn v7.4h, v22.4s\n" + "fcvtn2 v7.8h, v23.4s\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [dst_ptr1] "r"(dst_ptr1), [src_ptr1] "r"(src_ptr1), [strid_row] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else + for (int tr = 0; tr < C8NUM; ++tr) { + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]); + } + } +#endif + } + for (; ri < row; ++ri) { + const float *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]); + } + } + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]); + } + } +} + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) { + if (src_float16) { + Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col); + } else { + Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col); + } + return; +} + +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r16div = r / 16, r16mod = r % 16; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * 16 + d * 16 + r16mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_16 = UP_ROUND(row, C16NUM); + 
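// OutType_C8 stores dst in C8NUM-column blocks: index = c8div * C8NUM * row_16 + r * C8NUM + c8mod, + // with row padded up to a C16NUM multiple and col up to a C8NUM multiple; e.g. row = 20, col = 10 gives + // row_16 = 32, col_8 = 16, and element (r = 3, c = 9) lands at 1 * 8 * 32 + 3 * 8 + 1 = 281. +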
for (int r = 0; r < row_16; r++) { + for (int c = 0; c < col_8; c++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C16NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { // common conv and matmul + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { // common deconv + int col_8 = UP_ROUND(col, C8NUM); + int row_12 = UP_ROUND(row, C12NUM); + for (int r = 0; r < row_12; r++) { + for (int c = 0; c < col_8; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { // winograd conv + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C12NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +#ifdef ENABLE_DEBUG +void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + size_t index = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + 
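// a is packed as [row / C12NUM][deep][C12NUM] tiles and b as [col / C16NUM][deep][C16NUM] tiles, + // so ai and bi select lane r12mod / c16mod of depth slice d within the current tile. +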
size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[index] = value; + } + } +} +#endif + +void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int row, int col, int stride, int out_type) { + if (out_type == OutType_C8) { + // common deconv +#ifdef ENABLE_ARM64 + MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); +#else + MatMul12x8Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif + } else { + // winograd conv(OutType_TileC8) and common conv(OutType_Nhwc) and matmul(OutType_Nhwc) +#ifdef ENABLE_ARM64 + MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#else + MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif + } + return; +} + +#ifdef ENABLE_ARM64 +// 1*8 X 8*8 -> 1 X 8 +void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, + int col) { + int ci = col; + const float16_t *bv_base = b; + + while (ci > 0) { + float16x8_t acc_0 = vdupq_n_f16((float16_t)0.0); + if (bias != NULL) { + acc_0 = vld1q_f16(bias); + bias += C8NUM; + } + + int di = 0; + for (; di < depth - C8NUM + 1; di += C8NUM) { + float16x8_t av = vld1q_f16(a + di); + float16x8_t bv_0; + float16x8_t bv_1; + for (int i = 0; i < C8NUM; i += C2NUM) { + bv_0 = vld1q_f16(bv_base); // bv_i is one row of b: 8 columns of data + acc_0 = vfmaq_n_f16(acc_0, bv_0, av[i]); // av[i] is a single value of the input vector + bv_base += C8NUM; + + bv_1 = vld1q_f16(bv_base); // bv_i is one row of b: 8 columns of data + acc_0 = vfmaq_n_f16(acc_0, bv_1, av[i + 1]); // av[i] is a single value of the input vector + bv_base += C8NUM; + } + } + if (di < depth) { + for (; di < depth; ++di) { + float16_t ai = a[di]; + float16x8_t bv0 = vld1q_f16(bv_base); + bv_base += C8NUM; + acc_0 = vfmaq_n_f16(acc_0, bv0, ai); + } + } + if (act_type == ActType_Relu) { + acc_0 = vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)); + } + if (act_type == ActType_Relu6) { + acc_0 = vminq_f16(vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)), vdupq_n_f16((float16_t)6.0)); + } + + // only save actual col num data + if (ci < C8NUM) { + for (int i = 0; i < ci; ++i) { + c[i] = acc_0[i]; + } + return; + } + vst1q_f16(c, acc_0); + c += C8NUM; + ci -= C8NUM; + } +} +#endif + +#ifdef ENABLE_ARM82_A32 +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col) { + for (int ci = 0; ci < col; ci++) { + float value = 0; + for (int di = 0; di < depth; di++) { + value += a[di] * b[ci * depth + di]; + } + if (bias != NULL) value += bias[ci]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value); + c[ci] = value; + } +} +#endif + +void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int col) { +#ifdef ENABLE_ARM64 + MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); +#else + MatVecMulA32NeonFp16(a, b, c, bias, (int)act_type, depth, col); +#endif +} + +#ifdef ENABLE_ARM64 +static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { + size_t stride = col * 2; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 
{v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[stride]\n" + "ld1 {v9.8h}, [x10], %[stride]\n" + "ld1 {v10.8h}, [x10], %[stride]\n" + "ld1 {v11.8h}, [x10], %[stride]\n" + "ld1 {v12.8h}, [x10], %[stride]\n" + "ld1 {v13.8h}, [x10], %[stride]\n" + "ld1 {v14.8h}, [x10], %[stride]\n" + "ld1 {v15.8h}, [x10], %[stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [dst_c] "r"(dst_ptr), [src_c] "r"(src_ptr), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + size_t row_up_16 = UP_ROUND(row, C16NUM); + size_t row16 = row / C16NUM * C16NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // find 16 block unit + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_ARM64 + Row2Col16Block16(src_c, dst_c, col); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 
0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_16; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + return; +} + +#ifdef ENABLE_ARM64 +void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col) { + // Col12Major ==> Col8Major ==> Col4Major + const float16_t *src_r = src; + float16_t *dst_r = dst_ptr; + int ri = 0; + size_t col8 = col / C8NUM * C8NUM; + // find 16 block unit + for (; ri <= row - C12NUM; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + Transpose12x8ARM64Fp16(src_c, dst_c, col * C2NUM, C24NUM); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + if (ri < row) { + memcpy(dst_r, src_r, (row - ri) * col * C2NUM); + } +} + +void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col) { + // Row12 ==> Row8 ==> Row4 + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C12NUM; c += C12NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_FLOAT16X4 src_data1 = MS_LD_F16(src + r * col + c + C8NUM); + MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM, src_data); + MS_ST_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM + C8NUM, src_data1); + } + for (; c <= col - C8NUM; c += C8NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C8NUM, src_data); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r + c % C4NUM * row] = src[r * col + c]; + } + } +} + +void RowMajor2ColNMajorFp16srcFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { + const float16_t *src_r = src_ptr; + float16_t 
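/* dst cursor advances through C16, then C8, then C4 row-tile blocks */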
*dst_r = dst_ptr; + int ri = 0; + size_t col8 = col / C8NUM * C8NUM; + // find 16 block unit + for (; ri <= row - C16NUM; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + Row2Col16Block16(src_c, dst_c, col); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } +} + +void RowMajor2ColNMajorFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col, bool is_fp32_src) { + // Col16Major ==> Col8Major ==> Col4Major + if (!is_fp32_src) { + RowMajor2ColNMajorFp16srcFp16((const float16_t *)src_ptr, dst_ptr, row, col); + return; + } + const float *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + int ri = 0; + // find 16 block unit + for (; ri <= row - C16NUM; ri += C16NUM) { + for (int r = 0; r < C16NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C16NUM + r % C16NUM] = src_r[r * col + c]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + for (int r = 0; r < C8NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C8NUM + r % C8NUM] = src_r[r * col + c]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + for (int r = 0; r < C4NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C4NUM + r % C4NUM] = src_r[r * col + c]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ++ri) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } +} + +void RowMajor2RowNMajorFp16(const void *src_ptr, float16_t *dst, int row, int col, bool is_fp32_src) { + // Row16 ==> Row8 ==> Row4 + if (is_fp32_src) { + const float *src = (const float *)src_ptr; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C16NUM; c += C16NUM) { + const float *cur_src = src + r * col + c; + MS_FLOAT32X4X4 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM), MS_LDQ_F32(cur_src + C8NUM), + MS_LDQ_F32(cur_src + C12NUM)}; + MS_FLOAT16X4X4 res = { + MS_CVT_F16_F32(src_f32_data.val[0]), + MS_CVT_F16_F32(src_f32_data.val[1]), + MS_CVT_F16_F32(src_f32_data.val[2]), + MS_CVT_F16_F32(src_f32_data.val[3]), + 
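// sixteen fp32 values narrowed to fp16, forming one Row16-major block of dst +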
}; + MS_ST4_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, res); + } + for (; c <= col - C8NUM; c += C8NUM) { + const float *cur_src = src + r * col + c; + MS_FLOAT32X4X2 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM)}; + MS_FLOAT16X4X2 res = { + MS_CVT_F16_F32(src_f32_data.val[0]), + MS_CVT_F16_F32(src_f32_data.val[1]), + }; + MS_ST2_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, res); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_CVT_F16_F32(MS_LDQ_F32(src + r * col + c)); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; + } + } + return; + } + const float16_t *src = (const float16_t *)src_ptr; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C16NUM; c += C16NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); + MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); + MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); + } + for (; c <= col - C8NUM; c += C8NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, src_data); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; + } + } +} +#endif + +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM64 + Transpose12x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#elif ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (is_fp32_src) { + const float *fp32_src = (const float *)src; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div16 = r / 16; + int r_mod16 = r % 16; + dst[r_div16 * 16 * col + c * 16 + r_mod16] = (float16_t)(fp32_src[r * col + c]); + } + } + } else { + const float16_t *fp16_src = (const float16_t *)src; + RowMajor2Col16MajorFp16Opt(fp16_src, dst, row, col); + } + return; +} + +void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if 
(is_fp32_src) { + const float *fp32_src = (const float *)src; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div12 = r / 12; + int r_mod12 = r % 12; + dst[r_div12 * 12 * col + c * 12 + r_mod12] = (float16_t)(fp32_src[r * col + c]); + } + } + } else { + const float16_t *fp16_src = (const float16_t *)src; + RowMajor2Col12MajorFp16Opt(fp16_src, dst, row, col); + } + return; +} + +void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div16 = c / 16; + int c_mod16 = c % 16; + if (is_fp32_src) { + dst[c_div16 * 16 * row + r * 16 + c_mod16] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c_div16 * 16 * row + r * 16 + c_mod16] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) { + int col_align = UP_ROUND(col, C16NUM); + for (int r = 0; r < row; r++) { + int c = 0; + for (; c < col; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * col + c]; + } + for (; c < col_align; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = (float16_t)0.0; + } + } +} + +void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div12 = c / 12; + int c_mod12 = c % 12; + if (is_fp32_src) { + dst[c_div12 * 12 * row + r * 12 + c_mod12] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c_div12 * 12 * row + r * 12 + c_mod12] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + int down_c8 = col / C8NUM; + int stride = C8NUM * row; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c < down_c8; c++) { + MS_FLOAT16X8 src_data = MS_LDQ_F16((const float16_t *)src + r * col + c * C8NUM); + MS_STQ_F16(dst + c * stride + r * C8NUM, src_data); + } + c *= C8NUM; + for (; c < col; c++) { + int c_div8 = c / 8; + int c_mod8 = c % 8; + dst[c_div8 * stride + r * 8 + c_mod8] = ((const float16_t *)src)[r * col + c]; + } + } +} + +void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + if (is_fp32_src) { + dst[c * row + r] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c * row + r] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col8MajorFp16_arm64(const float16_t *src_c, float16_t *dst_c, size_t col) { + size_t stride = col * sizeof(float16_t); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v0.8h, v1.8h\n" + "zip1 v10.8h, v2.8h, v3.8h\n" + "zip2 v11.8h, v2.8h, v3.8h\n" + "zip1 v12.8h, v4.8h, v5.8h\n" + "zip2 v13.8h, v4.8h, v5.8h\n" + "zip1 v14.8h, v6.8h, v7.8h\n" + "zip2 v15.8h, v6.8h, v7.8h\n" + + "trn1 v16.4s, v8.4s, v10.4s\n" + "trn2 v17.4s, v8.4s, v10.4s\n" + "trn1 
v18.4s, v12.4s, v14.4s\n" + "trn2 v19.4s, v12.4s, v14.4s\n" + "trn1 v20.4s, v9.4s, v11.4s\n" + "trn2 v21.4s, v9.4s, v11.4s\n" + "trn1 v22.4s, v13.4s, v15.4s\n" + "trn2 v23.4s, v13.4s, v15.4s\n" + + "trn1 v0.2d, v16.2d, v18.2d\n" + "trn1 v1.2d, v17.2d, v19.2d\n" + "trn2 v2.2d, v16.2d, v18.2d\n" + "trn2 v3.2d, v17.2d, v19.2d\n" + "trn1 v4.2d, v20.2d, v22.2d\n" + "trn1 v5.2d, v21.2d, v23.2d\n" + "trn2 v6.2d, v20.2d, v22.2d\n" + "trn2 v7.2d, v21.2d, v23.2d\n" + + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], #64\n" + "st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x11], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + return; +} +#endif + +void RowMajor2Col8MajorFp16SrcFp16(const float16_t *src, float16_t *dst, int row, int col) { + int row8 = row / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + int col_skip = col / C8NUM * C8NUM; + int skip_size = C8NUM; +#else + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; +#endif + const float16_t *src_r = src; + float16_t *dst_r = dst; + + int ri = 0; + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8MajorFp16_arm64(src_c, dst_c, col); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C8NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + } + + for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = 0; + } + } +} + +void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (!is_fp32_src) { + return RowMajor2Col8MajorFp16SrcFp16(src, dst, row, col); + } + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div8 = r / 8; + int r_mod8 = r % 8; + dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)(((const float *)src)[r * col + c]); + } + } +} + +#if defined(ENABLE_DEBUG) && defined(ENABLE_ARM64) +// arm64 matmul +void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc) { + int r16 = row / C16NUM * C16NUM; + int r8 = row / C8NUM * C8NUM; + for (int r = 0; r < row; ++r) { + int row_tile = 0; + if (r < r16) { + row_tile = C16NUM; + } else if (r < r8) { + row_tile = C8NUM; + } else { + row_tile = C4NUM; + } + int index = r / row_tile * row_tile * depth + r % row_tile; + for (int t = 0; t < col; ++t) { + int c_div = t / C8NUM; + int c_mod = t % C8NUM; + float16_t res = bias[t]; + for (int d = 0; d < depth; ++d) { + res += a[index + d * row_tile] * b[c_div * depth * C8NUM + d * C8NUM + c_mod]; + } + c[r * col + t] = res; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h new file mode 100644 index 
0000000000000000000000000000000000000000..4a8e94bb8169dfa317d75485d317a52c81e5b840 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_MATMUL_FP16_H_ +#define NNACL_FP16_MATMUL_FP16_H_ + +#include +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" + +#define ADD_BIAS(value, bias, c) \ + if (bias != NULL) value = value + bias[c]; + +#define DO_RELU(value, act_type) \ + if (act_type == ActType_Relu) value = MSMAX(0.0f, value); + +#define DO_RELU6(value, act_type) \ + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ + if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); + +#ifdef __cplusplus +extern "C" { +#endif +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +#ifdef ENABLE_ARM64 +void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); + +void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col); + +void RowMajor2ColNMajorFp16(const void *src, float16_t *dst_ptr, int row, int col, bool is_fp32_src); + +void RowMajor2RowNMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type); +void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); + +void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +#ifdef ENABLE_DEBUG +void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); +#endif + +void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + 
int depth, int col); + +void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, + int col); +void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#elif ENABLE_ARM82_A32 +void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); + +void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#endif + +void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type); + +void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int row, int col, int stride, int out_type); + +void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int col); + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16); + +void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); + +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); + +void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col); + +void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_MATMUL_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..518a68abcb3de3a21268a4d4738a27e88971d9b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/matrix_fp16.h" + +void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16_t res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} + +#ifndef ENABLE_ARM64 +void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n, int in_channel) { + int cnt = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + for (int y = 0; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + matrix_c[cnt++] = tmp; + } + } + } +} +#endif + +void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c, + const float16_t *bias, int m, int k, int n) { + if (bias == NULL) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16x8_t res = vmovq_n_f16(0); + for (int i = 0; i < k; i++) { + res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n])); + } + matrix_c[count] = res; + count++; + } + } + } else { + int count = 0; + float16x8_t bias_ptr = vld1q_f16(bias); + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16x8_t res = vmovq_n_f16(0); + for (int i = 0; i < k; i++) { + res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n])); + } + matrix_c[count] = vaddq_f16(res, bias_ptr); + count++; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..e347c242661198296d8778c9e2c7b9c367f8f1dc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_MATRIX_FP16_H_ +#define NNACL_FP16_MATRIX_FP16_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n); + +void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c, + const float16_t *bias, int m, int k, int n); +void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n, int in_channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_MATRIX_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..06676190195eb00c1868d34fba25fd969e990dac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/one_hot_fp16.h" +#include "nnacl_c/errorcode.h" +int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num) { + if (indices == NULL || one_hot_param == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float16_t *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..6d10be8df70845625cf13e8e0cf7428b5b40a1f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_ONE_HOT_FP16_H_ +#define NNACL_FP16_ONE_HOT_FP16_H_ + +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/one_hot_parameter.h" +#include "nnacl_c/kernel/one_hot.h" + +#ifdef __cplusplus +extern "C" { +#endif +int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ONE_HOT_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..164696e65cabaf62340a2ef1657e0c1e97d4834c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c @@ -0,0 +1,933 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/pack_fp16.h" +#include + +#ifdef ENABLE_ARM +void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel) { + // nchw to nc8hw8 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float16_t *src_kernel = (float16_t *)src + i * 9; + float16_t *dst_kernel = (float16_t *)dst + (i / 8) * 96 + i % 8; + for (int y = 0; y < 3; y++) { + float16_t g0 = src_kernel[3 * y]; + float16_t g1 = src_kernel[3 * y + 1]; + float16_t g2 = src_kernel[3 * y + 2]; + + dst_kernel[32 * y] = g0; + dst_kernel[32 * y + 8] = (float16_t)0.5 * (g0 + g1 + g2); + dst_kernel[32 * y + 16] = (float16_t)0.5 * (g0 - g1 + g2); + dst_kernel[32 * y + 24] = g2; + } + } +} +#endif + +void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float16_t)); + } + } +} + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane, + const int channel) { + if (channel <= C8NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float16_t)); + return; + } + int tmp = DOWN_DIV(channel, C8NUM); + int c_res = channel - tmp * C8NUM; + int c8_block = tmp * plane * C8NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C8NUM; + int c = 0; + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src_data = vld1q_f16(src + src_kernel_offset + c); + vst1q_f16(dst + dst_kernel_offset + c * plane, src_data); + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel, int task_id, + int thread_count) { +#ifdef ENABLE_ARM64 + // Transpose16x8 in arm64 + const int hw_tile = C16NUM; +#else + // Transpose8x8 
in others + const int hw_tile = C8NUM; +#endif + int hw_align = plane / hw_tile; + int task_start = 0; + int task_end = plane; + if (thread_count > 0) { + int offset_hw = UP_DIV(hw_align, thread_count) * hw_tile; + task_start = offset_hw * task_id; + int count = plane - task_start; + if (count <= 0) { + return; + } + task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); + hw_align = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0); + } else { + hw_align *= hw_tile; + } + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const float16_t *src_batch = (const float16_t *)src + n * batch; + float16_t *dst_batch = (float16_t *)dst + n * batch; + int hw = task_start; + for (; hw < hw_align; hw += hw_tile) { + int c = 0; + for (; c < c8; c += C8NUM) { + const float16_t *src_ptr = src_batch + hw * channel + c; + float16_t *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#elif defined(ENABLE_ARM82_A32) + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#else + for (int tr = 0; tr < hw_tile; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const float16_t *src_ptr = src_batch + hw * channel + c; + float16_t *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < hw_tile; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < task_end; hw++) { + const float16_t *src_ptr = src_batch + hw * channel; + float16_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } +} + +void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) { + return PackNHWCToNCHWFp16(src, dst, batch, channel, plane, task_id, thread_count); +} + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic4 = UP_DIV(channel, C4NUM); + int c4_channel = ic4 * C4NUM; + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float16_t *dst_per_plane = (float16_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (float16_t *)src + batch_offset + i * channel, channel * sizeof(float16_t)); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic8 = UP_DIV(channel, C8NUM); + int c8_channel = ic8 * C8NUM; + int nhwc8_batch_unit_offset = ic8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float16_t *dst_per_plane = (float16_t *)dst + nhwc8_batch_offset + i * 
c8_channel; + memcpy(dst_per_plane, (float16_t *)src + batch_offset + i * channel, channel * sizeof(float16_t)); + for (int j = channel; j < c8_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel, + (float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy((float16_t *)dst, (float16_t *)src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; + ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * channel; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; 
+ int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp32ToNC8HW8Fp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float *src = (const float *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp16ToNC8HW8Fp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float16_t *src = (const float16_t *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNCHWFp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float16_t *src = (const float16_t *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_ROUND(channel, C8NUM); + for (int b = 0; b < batch; b++) { + const float16_t *batch_src = src + b * plane * c8; + float16_t *batch_dst = dst + b * plane * channel; + + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C8NUM; + size_t c_mod = c % C8NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset = c_div * plane * C8NUM + p * C8NUM + c_mod; + int dst_offset = c * plane + p; + batch_dst[dst_offset] = batch_src[src_offset]; + } + } + } +} + +void PackNHWCFp32ToNHWC8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + float16_t *dst_batch = dst + b * plane * c8_channel; + const float *src_batch = src + b * plane * channel; + for (int i = 0; i < plane; i++) { + float16_t *dst_plane = dst_batch + i * c8_channel; + const float *src_plane = src_batch + i * channel; + for (int c = 0; c < channel; c++) { + dst_plane[c] = (float16_t)(src_plane[c]); + } + } + } +} + +void PackNHWCFp32ToC8HWN8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * 
C8NUM + n * C8NUM + c8mod; + dst[dst_index] = (float16_t)(src[src_index]); + } + } + } + return; +} + +void PackNC8HW8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8 * C8NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C8NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c8 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C8NUM; + int dst_c_offset = dst_kernel_offset + c * C8NUM; + vst1q_f16(dst + dst_c_offset, vld1q_f16(src + src_c_offset)); + } + // res part + int res_c = channel - (c8 - 1) * C8NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i; + ((float16_t *)dst + dst_res_c_offset)[0] = ((float16_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNHWCFp16ToC8HWN8Fp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + dst[dst_index] = src[src_index]; + } + } + } + return; +} + +void PackNHWC8Fp16ToNHWCFp32(const float16_t *src, float *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + const float16_t *src_batch = src + b * plane * c8_channel; + float *dst_batch = dst + b * plane * channel; + for (int i = 0; i < plane; i++) { + const float16_t *src_plane = src_batch + i * c8_channel; + float *dst_plane = dst_batch + i * channel; + for (int c = 0; c < channel; c++) { + dst_plane[c] = (float16_t)(src_plane[c]); + } + } + } +} + +void PackNHWC8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + const float16_t *src_batch = src + b * plane * c8_channel; + float16_t *dst_batch = dst + b * plane * channel; + for (int i = 0; i < plane; i++) { + const float16_t *src_plane = src_batch + i * c8_channel; + float16_t *dst_plane = dst_batch + i * channel; + memcpy(dst_plane, src_plane, channel * sizeof(float16_t)); + } + } +} + +#ifdef ENABLE_ARM82_A32 +inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src]\n" + "mov r12, %[dst]\n" + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, 
d28\n" + + "vst1.16 {q0}, [r12], %[dst_stride]\n" + "vst1.16 {q2}, [r12], %[dst_stride]\n" + "vst1.16 {q4}, [r12], %[dst_stride]\n" + "vst1.16 {q6}, [r12], %[dst_stride]\n" + + "vst1.16 {q8}, [r12], %[dst_stride]\n" + "vst1.16 {q10}, [r12], %[dst_stride]\n" + "vst1.16 {q12}, [r12], %[dst_stride]\n" + "vst1.16 {q14}, [r12], %[dst_stride]\n" + + : + : [dst] "r"(dst), [src] "r"(src), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} + +inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vld1.16 {q1}, [r10], %[src_stride]\n" + "vld1.16 {q3}, [r10], %[src_stride]\n" + "vld1.16 {q5}, [r10], %[src_stride]\n" + "vld1.16 {q7}, [r10], %[src_stride]\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vtrn.16 d2, d6\n" + "vtrn.16 d3, d7\n" + "vtrn.16 d10, d14\n" + "vtrn.16 d11, d15\n" + + "vtrn.32 d2, d10\n" + "vtrn.32 d6, d14\n" + "vtrn.32 d3, d11\n" + "vtrn.32 d7, d15\n" + + "vst1.16 {q0, d2}, [r12], %[dst_stride]\n" + "vst1.16 {q2, d6}, [r12], %[dst_stride]\n" + "vst1.16 {q4, d10}, [r12], %[dst_stride]\n" + "vst1.16 {q6, d14}, [r12], %[dst_stride]\n" + + "vswp d3, d18\n" + "vswp d7, d22\n" + "vswp d11, d26\n" + "vswp d15, d30\n" + + "vst1.16 {q8, d18}, [r12], %[dst_stride]\n" + "vst1.16 {q10, d22}, [r12], %[dst_stride]\n" + "vst1.16 {q12, d26}, [r12], %[dst_stride]\n" + "vst1.16 {q14, d30}, [r12], %[dst_stride]\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_ARM64 +inline void Transpose4x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + dst_stride += dst_stride; + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "add x10, x11, %[dst_stride]\n" + + "zip1 v4.8h, v0.8h, v1.8h\n" + "zip1 v5.8h, v2.8h, v3.8h\n" + + "trn1 v6.4s, v4.4s, v5.4s\n" + "trn2 v7.4s, v4.4s, v5.4s\n" + + "trn1 v24.2d, v6.2d, v7.2d\n" + "trn2 v25.2d, v6.2d, v7.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + + "trn1 v10.4s, v8.4s, v9.4s\n" + "trn2 v11.4s, v8.4s, v9.4s\n" + + "trn1 v26.2d, v10.2d, v11.2d\n" + "trn2 v27.2d, v10.2d, v11.2d\n" + + "st1 {v24.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v25.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v26.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v27.8h}, 
[x10], %[tow_dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride), + [tow_dst_stride] "r"(2 * dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v24", "v25", "v26", + "v27"); +} + +inline void Transpose8x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + "add x10, x11, %[dst_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v26.2d, v20.2d, v22.2d\n" + "trn1 v25.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v28.2d, v12.2d, v14.2d\n" + "trn2 v30.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v24.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v25.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v26.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v27.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v28.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v29.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v30.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v31.8h}, [x10], %[tow_dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride), + [tow_dst_stride] "r"(2 * dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} + +void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { +#ifdef ENABLE_DEBUG + size_t col = src_stride / sizeof(float16_t); // src_stride is given in bytes + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * C12NUM + tr] = src_ptr[tr * col + tc]; + } + } +#else + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, 
v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + + "trn1 v28.2d, v20.2d, v20.2d\n" + "trn2 v29.2d, v20.2d, v20.2d\n" + "trn1 v30.2d, v21.2d, v21.2d\n" + "trn2 v31.2d, v21.2d, v21.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.4h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.4h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.4h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.4h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + + "trn1 v28.2d, v20.2d, v20.2d\n" + "trn2 v29.2d, v20.2d, v20.2d\n" + "trn1 v30.2d, v21.2d, v21.2d\n" + "trn2 v31.2d, v21.2d, v21.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.4h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.4h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.4h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.4h}, [x10], %[dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#endif +} + +inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + "ld1 {v12.8h}, [x10], %[src_stride]\n" + "ld1 {v13.8h}, [x10], %[src_stride]\n" + "ld1 {v14.8h}, [x10], %[src_stride]\n" + "ld1 {v15.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 
v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d2b3d0f75b25fb550dd6855137577f02c60fd3fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
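The ARM64 tile transposes above (8x8, 12x8, 16x8) all implement one contract with zip1/zip2 and trn1/trn2 shuffles: dst[c][r] = src[r][c] over an R x 8 tile, with strides passed in bytes. A scalar sketch of that contract (strides here in elements; a test reference, not part of the patch) is handy for checking the assembly:

#include <stdio.h>

/* Scalar reference for an R x 8 tile transpose: dst[c][r] = src[r][c]. */
static void TransposeTileRef(const float *src, float *dst, int rows, int src_stride, int dst_stride) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < 8; ++c) {
      dst[c * dst_stride + r] = src[r * src_stride + c];
    }
  }
}

int main(void) {
  float src[64], dst[64];
  for (int i = 0; i < 64; ++i) {
    src[i] = (float)i;
  }
  TransposeTileRef(src, dst, 8, 8, 8);
  printf("dst[1][0] = %g (src[0][1] = %g)\n", dst[8], src[1]);
  return 0;
}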
+ */ + +#ifndef NNACL_FP16_PACK_FP16_H_ +#define NNACL_FP16_PACK_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel); + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); + +void PackNHWCToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp32ToNC8HW8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp16ToNC8HW8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane, + const int channel); + +void PackNHWCFp32ToNHWC8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp32ToC8HWN8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp16ToC8HWN8Fp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWC8Fp16ToNHWCFp32(const float16_t *src, float *dst, int batch, int plane, int channel); + +void PackNHWC8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +#ifdef ENABLE_ARM82_A32 +void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); + +void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM64 +void Transpose4x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +void Transpose8x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride); +void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM +void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PACK_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..ecee18ad8ea94891b215ebed6e69319cc23ba19d --- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/pad_fp16.h" +#include "nnacl_c/common_func.h" + +void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *paddings, int tid, int thread_num) { + int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + for (in[4] = 0; in[4] < input_shape[4]; in[4]++) { + out[4] = in[4] + paddings[8]; + float16_t *dst = output_data + Offset6d(output_shape, out) + paddings[10]; + const float16_t *src = input_data + Offset6d(input_shape, in); + memcpy(dst, src, input_shape[5] * sizeof(float16_t)); + } + } + } + } + } +} + +void MirrorPadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *in_strides, + const int *out_strides, const int *padding, int mirror_offset, int begin, int end) { + for (int i = begin; i < end; ++i) { + output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, in_strides, out_strides, padding, mirror_offset)]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d1666622f8230cfc514bece8501624dc57d944a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
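PadFp16 above walks the six input coordinates, shifts each by its leading pad, and memcpy's one innermost row at a time into an output the caller has pre-filled with the pad value. The same idea in two dimensions (a float sketch; PadRef and its signature are illustrative only):

#include <stdio.h>
#include <string.h>

/* 2-D constant padding: out is (h + pads) x out_w, pre-filled with pad_value;
   each input row is copied whole at the (top, left)-shifted offset. */
static void PadRef(const float *in, float *out, int h, int w, int top, int left, int out_w, float pad_value,
                   int out_size) {
  for (int i = 0; i < out_size; ++i) {
    out[i] = pad_value;
  }
  for (int r = 0; r < h; ++r) {
    memcpy(out + (r + top) * out_w + left, in + r * w, w * sizeof(float));
  }
}

int main(void) {
  float in[2 * 2] = {1, 2, 3, 4};
  float out[4 * 4];
  PadRef(in, out, 2, 2, 1, 1, 4, 0.0f, 16);
  for (int i = 0; i < 16; ++i) {
    printf("%g%c", out[i], (i % 4 == 3) ? '\n' : ' ');
  }
  return 0;
}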
+ */ +#ifndef NNACL_FP16_PAD_FP16_H_ +#define NNACL_FP16_PAD_FP16_H_ + +#include "nnacl_c/fp32/pad_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *paddings, int tid, int thread_num); +void MirrorPadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *in_strides, + const int *out_strides, const int *padding, int mirror_offset, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PAD_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..fa72c846b577c0eb886366c347e32fe3c36eedbe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/pooling_fp16.h" +#include <float.h> +#include "nnacl_c/errorcode.h" + +int AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + float16_t min = (float16_t)pooling_args->minf; + float16_t max = (float16_t)pooling_args->maxf; + + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int c8 = channel / C8NUM; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + +#ifdef ENABLE_NEON + MS_FLOAT16X8 min_value = MS_MOVQ_F16(min); + MS_FLOAT16X8 max_value = MS_MOVQ_F16(max); +#endif + + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + for (int batch = 0; batch < pooling_args->output_batch_; batch++) { + const float16_t *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float16_t *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float16_t *src_plane_ptr = src_b_ptr; + float16_t *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + + for (int ci = 0; ci < c8; ci++) { + const float16_t *src_c_ptr = src_plane_ptr + ci * C8NUM; + float16_t *dst_c_ptr = dst_plane_ptr + ci * C8NUM; +#ifdef ENABLE_NEON + MS_FLOAT16X8 tmp_avg = MS_MOVQ_F16(0); +#else + float16_t tmp_avg0 = 0; + float16_t tmp_avg1 = 0; + float16_t tmp_avg2 = 0; + float16_t tmp_avg3 = 0; + float16_t tmp_avg4 = 0; + float16_t tmp_avg5 = 0; + float16_t tmp_avg6 = 0; + float16_t tmp_avg7 = 0; +#endif + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float16_t *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = MS_ADDQ_F16(tmp_avg, MS_LDQ_F16(src_win_ptr)); +#else + tmp_avg0 += src_win_ptr[0]; + tmp_avg1 += src_win_ptr[1]; + tmp_avg2 += src_win_ptr[2]; + tmp_avg3 += src_win_ptr[3]; + tmp_avg4 += src_win_ptr[4]; + tmp_avg5 += src_win_ptr[5]; + tmp_avg6 += src_win_ptr[6]; + tmp_avg7 += src_win_ptr[7]; +#endif + ++real_count; + } + } + if (real_count == 0) { + return NNACL_ERR; + } +#ifdef ENABLE_NEON + tmp_avg = MS_DIVQ_F16(tmp_avg, MS_MOVQ_F16((float16_t)real_count)); + MS_STQ_F16(dst_c_ptr, MS_MINQ_F16(MS_MAXQ_F16(tmp_avg, min_value), max_value)); +#else + dst_c_ptr[0] = MSMIN(MSMAX(tmp_avg0 / (float16_t)real_count, min), max); + dst_c_ptr[1] = MSMIN(MSMAX(tmp_avg1 / (float16_t)real_count, min), max); + dst_c_ptr[2] = MSMIN(MSMAX(tmp_avg2 / (float16_t)real_count, min), max); + dst_c_ptr[3] = MSMIN(MSMAX(tmp_avg3 / (float16_t)real_count, min), max); + dst_c_ptr[4] = MSMIN(MSMAX(tmp_avg4 / (float16_t)real_count, min), max); + dst_c_ptr[5] = MSMIN(MSMAX(tmp_avg5 / (float16_t)real_count, min), max); + dst_c_ptr[6] = MSMIN(MSMAX(tmp_avg6 / (float16_t)real_count, min), max); + dst_c_ptr[7] = MSMIN(MSMAX(tmp_avg7 / (float16_t)real_count, min), max); +#endif + } // c8 loop + int channel_s = c8 * C8NUM; + for (int ci = channel_s; ci < channel; ci++) { + const float16_t *src_c_ptr = src_plane_ptr + ci; + float16_t *dst_c_ptr = dst_plane_ptr + ci; + float16_t tmp_avg = 0; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float16_t *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += src_win_ptr[0]; + ++real_count; + } + } + if (real_count == 0) { + return NNACL_ERR; + } + tmp_avg = tmp_avg / (float16_t)real_count; + tmp_avg = fmax(tmp_avg, min); + tmp_avg = fmin(tmp_avg, max); + dst_c_ptr[0] = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop + return NNACL_OK; +} + +void MaxPoolingC8Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int 
real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; +#ifdef ENABLE_NEON + float16x8_t min_value = vdupq_n_f16(min); + float16x8_t max_value = vdupq_n_f16(max); +#endif + for (int j = 0; j < c8; j++) { + int in_channel_offset = in_batch_offset + j * C8NUM; + int out_channel_offset = out_plane_offset + j * C8NUM; +#ifdef ENABLE_NEON + float16x8_t tmp_max = vdupq_n_f16(min); +#else + float16_t tmp_max[8] = {min, min, min, min, min, min, min, min}; +#endif + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset)); +#else + for (int k = 0; k < C8NUM; k++) { + tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmaxq_f16(tmp_max, min_value); + tmp_max = vminq_f16(tmp_max, max_value); + vst1q_f16(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C8NUM; ++l) { + tmp_max[l] = fmax(tmp_max[l], min); + tmp_max[l] = fmin(tmp_max[l], max); + *(output_ptr + out_channel_offset + l) = tmp_max[l]; + } +#endif + } // c8 loop +} + +void MaxPoolingC4Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; + int c8_res = channel % C8NUM; + int c4 = c8_res / C4NUM; +#ifdef ENABLE_NEON + float16x4_t min_value2 = vdup_n_f16(min); + float16x4_t max_value2 = vdup_n_f16(max); +#endif + int c4_offset = c8 * C8NUM; + for (int j = 0; j < c4; j++) { + int in_channel_offset = in_batch_offset + c4_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + c4_offset + j * C4NUM; +#ifdef ENABLE_NEON + float16x4_t tmp_max = vdup_n_f16(min); +#else + float16_t tmp_max[4] = {min, min, min, min}; +#endif + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset)); +#else + for (int k = 0; k < C4NUM; k++) { + tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmax_f16(tmp_max, min_value2); + tmp_max = vmin_f16(tmp_max, max_value2); + vst1_f16(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C4NUM; ++l) { + tmp_max[l] = fmax(tmp_max[l], min); + tmp_max[l] = fmin(tmp_max[l], max); + output_ptr[out_channel_offset + l] = tmp_max[l]; + } +#endif + } // c4 loop +} +void MaxPoolingC1Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = 
pooling_args->input_w_; + int c8 = channel / C8NUM; + int c8_res = channel % C8NUM; + int c4 = c8_res / C4NUM; + int channel_s = c8 * C8NUM + c4 * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float16_t tmp_max = -FLT_MAX; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + } // win_w loop + } // win_h loop + tmp_max = fmax(tmp_max, min); + tmp_max = fmin(tmp_max, max); + output_ptr[out_channel_offset] = tmp_max; + } // channel_res loop +} + +void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + float16_t min = (float16_t)pooling_args->minf; + float16_t max = (float16_t)pooling_args->maxf; + + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int output_batch = pooling_args->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + + // input channel is equal to output channel + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + MaxPoolingC8Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + MaxPoolingC4Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + MaxPoolingC1Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..6f42e09e9e6a9400ab746fdfef6625aed9ce4f37 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_POOLING_FP16_H_ +#define NNACL_FP16_POOLING_FP16_H_ + +#include <float.h> +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); + +void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_POOLING_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..c2a9d157bcf6c87a1e24ab31ee95864eb15ce2ac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c @@ -0,0 +1,117 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
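Both pooling kernels above share the same window arithmetic: the input anchor for output position (oh, ow) is oh*stride - pad, and the h/w loop bounds are clamped so taps outside the image are skipped; the average path then divides by the count of surviving taps rather than by win_h*win_w. A single-channel float sketch (names hypothetical):

#include <stdio.h>

#define MSMAX(a, b) ((a) > (b) ? (a) : (b))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

/* Single-channel average pool mirroring the kernels' window clipping. */
static float AvgPoolOneRef(const float *in, int in_h, int in_w, int oh, int ow, int win_h, int win_w, int stride,
                           int pad) {
  int ih = oh * stride - pad, iw = ow * stride - pad;
  int h0 = MSMAX(0, -ih), h1 = MSMIN(win_h, in_h - ih);
  int w0 = MSMAX(0, -iw), w1 = MSMIN(win_w, in_w - iw);
  float sum = 0.0f;
  int count = 0;
  for (int h = h0; h < h1; ++h) {
    for (int w = w0; w < w1; ++w, ++count) {
      sum += in[(ih + h) * in_w + iw + w];
    }
  }
  return count > 0 ? sum / (float)count : 0.0f; /* the kernels return an error when count == 0 */
}

int main(void) {
  float in[3 * 3] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  /* 2x2 window, stride 1, pad 1: the corner output sees exactly one valid tap */
  printf("%g\n", AvgPoolOneRef(in, 3, 3, 0, 0, 2, 2, 1, 1));
  return 0;
}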
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/power_fp16.h" +#include "nnacl_c/errorcode.h" + +#if defined(ENABLE_NEON) +float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { + int tmp = (int)(*(float16_t *)exponent); + int exp = abs(tmp); + float16x8_t result = vmovq_n_f16(1.0f); + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + if (tmp >= 0) { + return result; + } + return 1 / result; +} +#endif + +float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { + int tmp = *(float16_t *)exponent; + int exp = abs(tmp); + float16_t result = 1; + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + return tmp >= 0 ? result : 1 / result; +} + +void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift) { + PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; +#if defined(ENABLE_NEON) + PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; +#endif + + if (CheckIntegerFp16(*exponent)) { +#if defined(ENABLE_NEON) + PowerSimdFunFp16_ = OptimizedPowerSimdFp16; +#endif + PowerScalarFunFp16_ = OptimizedPowerScalarFp16; + } else { +#if defined(ENABLE_NEON) + PowerSimdFunFp16_ = StdPowerSimdFp16; +#endif + PowerScalarFunFp16_ = StdPowerScalarFp16; + } + int i = 0; +#ifdef ENABLE_NEON + int len_c8 = DOWN_ROUND(len, C8NUM); + float16x8_t scale_8 = vmovq_n_f16(scale); + float16x8_t shift_8 = vmovq_n_f16(shift); + for (; i < len_c8; i += C8NUM) { + float16x8_t result = PowerSimdFunFp16_(scale_8 * vld1q_f16(input + i) + shift_8, exponent); + vst1q_f16(output + i, result); + } +#endif + for (; i < len; ++i) { + output[i] = PowerScalarFunFp16_(scale * input[i] + shift, exponent); + } +} + +void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift) { + int i = 0; + PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; +#ifdef ENABLE_NEON + int len_c8 = DOWN_ROUND(len, C8NUM); + float16x8_t scale_8 = vmovq_n_f16(scale); + float16x8_t shift_8 = vmovq_n_f16(shift); + for (; i < len_c8; i += C8NUM) { + float16x8_t tmp_8 = scale_8 * vld1q_f16(input + i) + shift_8; + for (int j = 0; j < 8; ++j) { + PowerScalarFunFp16_ = CheckIntegerFp16(exponent[i + j]) ? OptimizedPowerScalarFp16 : StdPowerScalarFp16; + output[i + j] = PowerScalarFunFp16_(tmp_8[j], exponent + i + j); + } + } +#endif + for (; i < len; ++i) { + PowerScalarFunFp16_ = CheckIntegerFp16(exponent[i]) ? OptimizedPowerScalarFp16 : StdPowerScalarFp16; + output[i] = PowerScalarFunFp16_(scale * input[i] + shift, exponent + i); + } +} + +int PowerFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift, + bool broadcast) { + if (input == NULL || exponent == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + PowerFunFp16 PowerFunFp16_ = NULL; + PowerFunFp16_ = broadcast ? 
PowerBroadCastFp16 : PowerSingleFp16;
+ PowerFunFp16_(input, exponent, output, len, scale, shift);
+ return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..a46139e9d486b53c8083b53cf7d393d5338ee6c3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_POWER_FP16_H_
+#define NNACL_FP16_POWER_FP16_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h"
+#include "nnacl_c/pow_parameter.h"
+
+#if defined(ENABLE_NEON)
+typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent);
+#endif
+typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent);
+typedef void (*PowerFunFp16)(const float16_t *, const float16_t *, float16_t *, int, float, float);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+static inline bool CheckIntegerFp16(float16_t f) { return floorf(f) == f; }
+
+static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) {
+ return powf(x, *(float16_t *)exponent);
+}
+
+#if defined(ENABLE_NEON)
+static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) {
+ float16x8_t result;
+ result[0] = powf(x[0], *(float16_t *)exponent);
+ result[1] = powf(x[1], *(float16_t *)exponent);
+ result[2] = powf(x[2], *(float16_t *)exponent);
+ result[3] = powf(x[3], *(float16_t *)exponent);
+ result[4] = powf(x[4], *(float16_t *)exponent);
+ result[5] = powf(x[5], *(float16_t *)exponent);
+ result[6] = powf(x[6], *(float16_t *)exponent);
+ result[7] = powf(x[7], *(float16_t *)exponent);
+ return result;
+}
+#endif
+int PowerFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift,
+ bool broadcast);
+void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale,
+ float shift);
+void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale,
+ float shift);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_POWER_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..0a062adc6b20403424806fea8b49b051ef3b88c9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/prelu_fp16.h"
+
+#ifdef ENABLE_ARM64
+static inline void PReluFp164x32(const float16_t *in, float16_t *out, const float16_t *cur_slope, size_t step) {
+ asm volatile(
+ "mov x10, %[in]\n"
+ "mov x11, %[out]\n"
+ "mov x12, %[cur_slope]\n"
+ "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12]\n"
+ "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n"
+ "fmul v16.8h, v0.8h, v4.8h\n"
+ "fmul v17.8h, v1.8h, v5.8h\n"
+ "fmul v18.8h, v2.8h, v6.8h\n"
+ "fmul v19.8h, v3.8h, v7.8h\n"
+ "fcmgt v20.8h, v0.8h, #0\n"
+ "fcmgt v21.8h, v1.8h, #0\n"
+ "fcmgt v22.8h, v2.8h, #0\n"
+ "fcmgt v23.8h, v3.8h, #0\n"
+ "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], %[step]\n"
+ "bif v0.16b, v16.16b, v20.16b\n"
+ "bif v1.16b, v17.16b, v21.16b\n"
+ "bif v2.16b, v18.16b, v22.16b\n"
+ "bif v3.16b, v19.16b, v23.16b\n"
+ "fmul v8.8h, v24.8h, v4.8h\n"
+ "fmul v9.8h, v25.8h, v5.8h\n"
+ "fmul v10.8h, v26.8h, v6.8h\n"
+ "fmul v11.8h, v27.8h, v7.8h\n"
+ "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n"
+ "fcmgt v12.8h, v24.8h, #0\n"
+ "fcmgt v13.8h, v25.8h, #0\n"
+ "fcmgt v14.8h, v26.8h, #0\n"
+ "fcmgt v15.8h, v27.8h, #0\n"
+ "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n"
+ "bif v24.16b, v8.16b, v12.16b\n"
+ "bif v25.16b, v9.16b, v13.16b\n"
+ "bif v26.16b, v10.16b, v14.16b\n"
+ "bif v27.16b, v11.16b, v15.16b\n"
+ "fmul v16.8h, v0.8h, v4.8h\n"
+ "fmul v17.8h, v1.8h, v5.8h\n"
+ "fmul v18.8h, v2.8h, v6.8h\n"
+ "fmul v19.8h, v3.8h, v7.8h\n"
+ "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11], %[step]\n"
+ "fcmgt v20.8h, v0.8h, #0\n"
+ "fcmgt v21.8h, v1.8h, #0\n"
+ "fcmgt v22.8h, v2.8h, #0\n"
+ "fcmgt v23.8h, v3.8h, #0\n"
+ "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10]\n"
+ "bif v0.16b, v16.16b, v20.16b\n"
+ "bif v1.16b, v17.16b, v21.16b\n"
+ "bif v2.16b, v18.16b, v22.16b\n"
+ "bif v3.16b, v19.16b, v23.16b\n"
+ "fmul v8.8h, v24.8h, v4.8h\n"
+ "fmul v9.8h, v25.8h, v5.8h\n"
+ "fmul v10.8h, v26.8h, v6.8h\n"
+ "fmul v11.8h, v27.8h, v7.8h\n"
+ "fcmgt v12.8h, v24.8h, #0\n"
+ "fcmgt v13.8h, v25.8h, #0\n"
+ "fcmgt v14.8h, v26.8h, #0\n"
+ "fcmgt v15.8h, v27.8h, #0\n"
+ "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n"
+ "bif v24.16b, v8.16b, v12.16b\n"
+ "bif v25.16b, v9.16b, v13.16b\n"
+ "bif v26.16b, v10.16b, v14.16b\n"
+ "bif v27.16b, v11.16b, v15.16b\n"
+ "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11]\n"
+ :
+ : [in] "r"(in), [out] "r"(out), [cur_slope] "r"(cur_slope), [step] "r"(step)
+ : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
+ "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
+}
+#endif
+
+void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel) {
+ int i = start;
+#ifdef ENABLE_ARM64
+ for (; i <= end - C4NUM; i += C4NUM) {
+ const float16_t *cur_in = input + i * channel;
+ float16_t *cur_out = output + i * channel;
+ int j = 0;
+ for (; j <= channel - C32NUM; j += C32NUM) {
+ const float16_t *in = cur_in + j;
+ float16_t *out = cur_out + j;
+ const float16_t *cur_slope = slope + j;
+ size_t step = channel *
sizeof(float16_t); + PReluFp164x32(in, out, cur_slope, step); + } + for (; j < channel; j++) { + cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); + cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; + cur_out[j + 2 * channel] = + (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); + cur_out[j + 3 * channel] = + (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); + } + } +#endif + for (; i < end; i++) { + const float16_t *cur_in = input + i * channel; + float16_t *cur_out = output + i * channel; + int j = 0; +#ifdef ENABLE_NEON + for (; j <= channel - C8NUM; j += C8NUM) { + float16x8_t in = vld1q_f16(cur_in + j); + float16x8_t s = vld1q_f16(slope + j); + float16x8_t mul = vmulq_f16(in, s); + uint16x8_t mask = vcleq_f16(in, vmovq_n_f16(0.0f)); + vst1q_f16(cur_out + j, vbslq_f16(mask, mul, in)); + } +#endif + for (; j < channel; j++) { + cur_out[j] = cur_in[j] > 0 ? cur_in[j] : cur_in[j] * slope[j]; + } + } +} + +void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end) { + int i = start; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t slope_data = vdupq_n_f16(slope); + for (; i <= end - C8NUM; i += C8NUM) { + float16x8_t src_tmp = vld1q_f16(input + i); + float16x8_t mul_tmp = vmulq_f16(src_tmp, slope_data); + uint16x8_t mask = vcleq_f16(src_tmp, zero_data); + vst1q_f16(output + i, vbslq_f16(mask, mul_tmp, src_tmp)); + } +#endif + for (; i < end; i++) { + output[i] = input[i] > 0 ? input[i] : input[i] * slope; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..01a12799dcbb0e0c36a71a77feef7835b2acc120 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_FP16_PRELU_FP16_H_
+#define NNACL_FP16_PRELU_FP16_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel);
+
+void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_PRELU_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..cf5ab3d6617617b9fa34d69e3bee26927963cf41
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c
@@ -0,0 +1,290 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/fp16/quant_dtype_cast_fp16.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef ENABLE_ARM64
+void Int8ToFp16_arm64(const int8_t *quant_values, float16_t *dst, float scale, int32_t zp, int size) {
+ asm volatile(
+ "mov w8, %w[size]\n"
+ "cmp w8, #0\n"
+ "beq 2f\n"
+
+ "dup v20.4s, %w[zp32]\n"
+ "dup v21.4s, %w[scale]\n"
+
+ "cmp w8, #16\n"
+ "blt 1f\n"
+
+ "0:\n"
+ "subs w8, w8, #16\n"
+ "ld1 {v7.16b}, [%[quant_values]], #16\n"
+
+ "sxtl v8.8h, v7.8b\n"
+ "sxtl2 v9.8h, v7.16b\n"
+
+ "sxtl v0.4s, v8.4h\n"
+ "sxtl2 v1.4s, v8.8h\n"
+ "sxtl v2.4s, v9.4h\n"
+ "sxtl2 v3.4s, v9.8h\n"
+ "sub v0.4s, v0.4s, v20.4s\n"
+ "sub v1.4s, v1.4s, v20.4s\n"
+ "sub v2.4s, v2.4s, v20.4s\n"
+ "sub v3.4s, v3.4s, v20.4s\n"
+ "scvtf v4.4s, v0.4s\n"
+ "scvtf v5.4s, v1.4s\n"
+ "scvtf v6.4s, v2.4s\n"
+ "scvtf v7.4s, v3.4s\n"
+
+ "fmul v0.4s, v4.4s, v21.4s\n"
+ "fmul v1.4s, v5.4s, v21.4s\n"
+ "fmul v2.4s, v6.4s, v21.4s\n"
+ "fmul v3.4s, v7.4s, v21.4s\n"
+
+ "fcvtn v4.4h, v0.4s\n"
+ "fcvtn2 v4.8h, v1.4s\n"
+ "fcvtn v5.4h, v2.4s\n"
+ "fcvtn2 v5.8h, v3.4s\n"
+
+ "st1 {v4.8h, v5.8h}, [%[dst]], #32\n"
+ "beq 2f\n"
+ "cmp w8, #16\n"
+ "bge 0b\n"
+
+ "1:\n"
+ "ldrsb w9, [%[quant_values]], #1\n"
+
+ "subs w8, w8, #1\n"
+ "sub w9, w9, %w[zp32]\n"
+ "scvtf s9, w9\n"
+
+ "fmul s9, s9, s21\n"
+ "fcvtn v4.4h, v9.4s\n"
+ "str h4, [%[dst]], #2\n"
+ "bne 1b\n"
+
+ "2:\n"
+
+ :
+ : [quant_values] "r"(quant_values), [dst] "r"(dst), [scale] "r"(scale), [zp32] "r"(zp), [size] "r"(size)
+ : "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21");
}
+#endif
+
+int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) {
+ if (quant_values == NULL || real_values == NULL) {
+ return NNACL_PARAM_INVALID;
+ }
+#ifdef ENABLE_ARM64
+ Int8ToFp16_arm64(quant_values, real_values, scale, zp, size);
+#else
+ for (int i = 0; i < size; ++i) {
+ real_values[i] = (quant_values[i] - zp) * scale;
+ }
+#endif
+ return NNACL_OK;
+}
+
+#ifdef ENABLE_ARM64
+void Fp16ToInt8_arm64(const float16_t *real_values, int8_t
*quant_values, float scale, int32_t zp, int size) { + const float one = 1.0f; + const float ivs = one / scale; + const int32_t min_value = -128; + const int32_t max_value = 127; + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, wzr\n" + "beq 3f\n" + + "dup v28.4s, %w[ivs]\n" + "dup v29.4s, %w[min_value]\n" + "dup v30.4s, %w[max_value]\n" + + "cmp w8, #32\n" + "blt 2f\n" + "1:\n" // loop 32 + "subs w8, w8, #32\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[real_values]], #64\n" + "fcvtl v8.4s, v0.4h\n" + "fcvtl2 v9.4s, v0.8h\n" + "fcvtl v10.4s, v1.4h\n" + "fcvtl2 v11.4s, v1.8h\n" + "fcvtl v12.4s, v2.4h\n" + "fcvtl2 v13.4s, v2.8h\n" + "fcvtl v14.4s, v3.4h\n" + "fcvtl2 v15.4s, v3.8h\n" + + "dup v16.4s, %w[zp]\n" + "dup v17.4s, %w[zp]\n" + "dup v18.4s, %w[zp]\n" + "dup v19.4s, %w[zp]\n" + "dup v20.4s, %w[zp]\n" + "dup v21.4s, %w[zp]\n" + "dup v22.4s, %w[zp]\n" + "dup v23.4s, %w[zp]\n" + "scvtf v16.4s, v16.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v20.4s, v20.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v23.4s, v23.4s\n" + + "fmla v16.4s, v8.4s, v28.4s\n" + "fmla v17.4s, v9.4s, v28.4s\n" + "fmla v18.4s, v10.4s, v28.4s\n" + "fmla v19.4s, v11.4s, v28.4s\n" + "fmla v20.4s, v12.4s, v28.4s\n" + "fmla v21.4s, v13.4s, v28.4s\n" + "fmla v22.4s, v14.4s, v28.4s\n" + "fmla v23.4s, v15.4s, v28.4s\n" + + "fcvtas v8.4s, v16.4s\n" + "fcvtas v9.4s, v17.4s\n" + "fcvtas v10.4s, v18.4s\n" + "fcvtas v11.4s, v19.4s\n" + "fcvtas v12.4s, v20.4s\n" + "fcvtas v13.4s, v21.4s\n" + "fcvtas v14.4s, v22.4s\n" + "fcvtas v15.4s, v23.4s\n" + + "smax v8.4s, v8.4s, v29.4s\n" + "smax v9.4s, v9.4s, v29.4s\n" + "smax v10.4s, v10.4s, v29.4s\n" + "smax v11.4s, v11.4s, v29.4s\n" + "smax v12.4s, v12.4s, v29.4s\n" + "smax v13.4s, v13.4s, v29.4s\n" + "smax v14.4s, v14.4s, v29.4s\n" + "smax v15.4s, v15.4s, v29.4s\n" + "smin v8.4s, v8.4s, v30.4s\n" + "smin v9.4s, v9.4s, v30.4s\n" + "smin v10.4s, v10.4s, v30.4s\n" + "smin v11.4s, v11.4s, v30.4s\n" + "smin v12.4s, v12.4s, v30.4s\n" + "smin v13.4s, v13.4s, v30.4s\n" + "smin v14.4s, v14.4s, v30.4s\n" + "smin v15.4s, v15.4s, v30.4s\n" + + "sqxtn v16.4h, v8.4s\n" + "sqxtn2 v16.8h, v9.4s\n" + "sqxtn v17.4h, v10.4s\n" + "sqxtn2 v17.8h, v11.4s\n" + "sqxtn v18.4h, v12.4s\n" + "sqxtn2 v18.8h, v13.4s\n" + "sqxtn v19.4h, v14.4s\n" + "sqxtn2 v19.8h, v15.4s\n" + "sqxtn v20.8b, v16.8h\n" + "sqxtn2 v20.16b, v17.8h\n" + "sqxtn v21.8b, v18.8h\n" + "sqxtn2 v21.16b, v19.8h\n" + + "st1 {v20.16b, v21.16b}, [%[quant_values]], #32\n" + + "beq 3f\n" + "cmp w8, #32\n" + "bge 1b\n" + + "2:\n" // 1 by 1 + "scvtf s10, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr h0, [%[real_values]], #2\n" + "fcvt s0, h0\n" + "fmul s0, s0, s28\n" + "fadd s0, s0, s10\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v29.4s\n" + "smin v0.4s, v0.4s, v30.4s\n" + "sqxtn v0.4h, v0.4s\n" + "sqxtn v0.8b, v0.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + "bne 2b\n" + + "3:\n" + : + : [size] "r"(size), [ivs] "r"(ivs), [real_values] "r"(real_values), [quant_values] "r"(quant_values), [zp] "r"(zp), + [min_value] "r"(min_value), [max_value] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v28", "v29", "v30"); +} +#endif + +int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + 
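// The assembly kernel above converts 32 fp16 values per pass: fcvtl widens them to fp32,
+ // fmla accumulates zp + value * (1 / scale), fcvtas rounds to the nearest integer,
+ // smax/smin clamp to [-128, 127], and the saturating sqxtn chain narrows back to int8.
+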
Fp16ToInt8_arm64(real_values, quant_values, scale, zp, size);
+#else
+ const int8_t min_value = -128;
+ const int8_t max_value = 127;
+ for (int i = 0; i < size; ++i) {
+ if (real_values[i] == INFINITY) {
+ quant_values[i] = max_value;
+ continue;
+ }
+ if (real_values[i] == -INFINITY) {
+ quant_values[i] = min_value;
+ continue;
+ }
+ float temp = round((float)real_values[i] / scale + zp);
+ if (temp > max_value) {
+ quant_values[i] = max_value;
+ } else if (temp < min_value) {
+ quant_values[i] = min_value;
+ } else {
+ quant_values[i] = (int8_t)temp;
+ }
+ }
+#endif
+ return NNACL_OK;
+}
+
+int DoDequantizeUInt8ToFp16(const uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) {
+ if (quant_values == NULL || real_values == NULL) {
+ return NNACL_PARAM_INVALID;
+ }
+ uint8_t zp_ = (uint8_t)zp;
+
+ for (int i = 0; i < size; ++i) {
+ real_values[i] = (quant_values[i] - zp_) * scale;
+ }
+ return NNACL_OK;
+}
+
+int DoQuantizeFp16ToUInt8(const float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) {
+ if (quant_values == NULL || real_values == NULL) {
+ return NNACL_PARAM_INVALID;
+ }
+
+ for (int i = 0; i < size; ++i) {
+ if (isinf((float)real_values[i])) {
+ quant_values[i] = (real_values[i] > 0) ? 255 : 0;  // +inf saturates to 255, -inf to 0
+ continue;
+ }
+ float temp = round((float)real_values[i] / scale + zp);
+ if (temp > 255.0f) {
+ quant_values[i] = 255;
+ } else if (temp < 0.0f) {
+ quant_values[i] = 0;
+ } else {
+ quant_values[i] = (uint8_t)temp;
+ }
+ }
+ return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..b45b16a8dfbe1e8fe865ee37165c35f20049fa9b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_QUANTDTYPECAST_FP16_H_
+#define NNACL_FP16_QUANTDTYPECAST_FP16_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+
+int DoDequantizeUInt8ToFp16(const uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp16ToUInt8(const float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_QUANTDTYPECAST_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..ce1e26d02f2ab427d1268c22931b27425c42b89a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/ragged_range_fp16.h"
+
+void RaggedRangeFp16(const float16_t *starts, const float16_t *limits, const float16_t *deltas, int *splits,
+ float16_t *value, const RaggedRangeStruct *param) {
+ splits[0] = 0;
+ for (int i = 0; i < param->rows_; i++) {
+ float16_t start = param->starts_is_scalar_ ? starts[0] : starts[i];
+ float16_t limit = param->limits_is_scalar_ ? limits[0] : limits[i];
+ float16_t delta = param->deltas_is_scalar_ ? deltas[0] : deltas[i];
+ int len = NNACL_MAX((int)ceil((float16_t)(limit - start) / delta), 0);
+ splits[i + 1] = splits[i] + len;
+ for (int j = 0; j < len; j++) {
+ *value++ = start;
+ start += delta;
+ }
+ }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..91088cfead959b147a6cfa6a07892a20db0fa217
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_RAGGED_RANGE_FP16_H_
+#define NNACL_FP16_RAGGED_RANGE_FP16_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/kernel/ragged_range.h"
+
+void RaggedRangeFp16(const float16_t *starts, const float16_t *limits, const float16_t *deltas, int *splits,
+ float16_t *value, const RaggedRangeStruct *param);
+
+#endif  // NNACL_FP16_RAGGED_RANGE_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..9eb9d83366319fb4e246fa992876eb580c57c5af
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_RANGE_FP16_H_
+#define NNACL_FP16_RANGE_FP16_H_
+
+#include "nnacl_c/op_base.h"
+
+static inline void RangeFp16(float16_t *output_ptr, float16_t start, float16_t delta, int nums) {
+ for (int i = 0; i < nums; ++i, start += delta) {
+ output_ptr[i] = start;
+ }
+}
+
+#endif  // NNACL_FP16_RANGE_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..5c9f4ac4be2256f614e8b61e9d87453e08c73c22
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <float.h>
+#include <math.h>
+#include "nnacl_c/fp16/reduce_fp16.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h"
+
+int ReduceMeanFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num) {
+ if (src_data == NULL || dst_data == NULL) {
+ return NNACL_NULL_PTR;
+ }
+ if (axis_size == 0) {
+ return NNACL_ERR;
+ }
+ int i, j, k;
+ for (j = tid; j < outer_size; j += thread_num) {
+ const float16_t *outer_src = src_data + j * axis_size * inner_size;
+ float16_t *outer_dst = dst_data + j * inner_size;
+ for (k = 0; k < inner_size; k++) {
+ const float16_t *inner_src = outer_src + k;
+ float16_t *inner_dst = outer_dst + k;
+ float tmp = 0.0;
+ for (i = 0; i < axis_size; i++) {
+ tmp += inner_src[i * inner_size];
+ }
+ *inner_dst = (float16_t)(tmp / axis_size);
+ }
+ }
+ return NNACL_OK;
+}
+
+int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num) {
+ if (src_data == NULL || dst_data == NULL) {
+ return NNACL_NULL_PTR;
+ }
+ int i, j, k;
+ for (j = tid; j < outer_size; j += thread_num) {
+ const float16_t *outer_src = src_data + j * axis_size * inner_size;
+ float16_t *outer_dst = dst_data + j * inner_size;
+ for (k = 0; k < inner_size; k++) {
+ const float16_t *inner_src = outer_src + k;
+ float16_t *inner_dst = outer_dst + k;
+ float tmp = -FLT_MAX;
+ for (i = 0; i < axis_size; i++) {
+ tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
+ }
+ *inner_dst = tmp;
+ }
+ }
+ return NNACL_OK;
+}
+
+int ReduceMinFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num) {
+ if (src_data == NULL || dst_data == NULL) {
+ return NNACL_NULL_PTR;
+ }
+ if (thread_num == 0) {
+ return NNACL_PARAM_INVALID;
+ }
+ int i, j, k;
+ for (j = tid; j < outer_size; j += thread_num) {
+ const float16_t *outer_src = src_data + j * axis_size * inner_size;
+ float16_t *outer_dst = dst_data + j * inner_size;
+ for (k = 0; k < inner_size; k++) {
+ const float16_t *inner_src = outer_src + k;
+ float16_t *inner_dst = outer_dst + k;
+ float16_t tmp = 65504;  // fp16 max value
+ for (i = 0; i < axis_size; i++) {
+ tmp = tmp < inner_src[i * inner_size] ?
tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceProdFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float16_t tmp = 1.0f; + for (i = 0; i < axis_size; i++) { + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + tmp = vaddq_f16(tmp, vld1q_f16(inner_src + k * inner_size)); + } + vst1q_f16(inner_dst, tmp); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceL2NormFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + float16x8_t src = vld1q_f16(inner_src + k * inner_size); + tmp = MS_FMAQ_F16(tmp, src, src); + } + vst1q_f16(inner_dst, MS_SQRTFX8_F16(tmp)); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size] * inner_src[k * inner_size]; + } + dst_data[j] = sqrtf(tmp); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..638f76ded436079476da48f975efba6ba337dddf --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_REDUCE_FP16_H_
+#define NNACL_FP16_REDUCE_FP16_H_
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/reduce_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ReduceMeanFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num);
+int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num);
+int ReduceMinFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num);
+int ReduceProdFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num);
+int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num);
+int ReduceL2NormFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data,
+ int tid, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_REDUCE_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..a5b0b318c3e33abdd41915098c18e6616dd65d02
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c
@@ -0,0 +1,380 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <math.h>
+#include "nnacl_c/fp16/resize_fp16.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/errorcode.h"
+
+void CalculateCoordinateFp16(float16_t out, int in, int *bottom, int *top, float16_t *bottom_weight) {
+ *bottom = (int)(floorf(out));
+ *bottom = *bottom >= 0 ? *bottom : 0;  // extrapolate may generate neg value
+ *top = *bottom + 1 < in ?
(*bottom + 1) : (in - 1);
+ float16_t top_weight = (float16_t)out - (float16_t)(*bottom);
+ *bottom_weight = 1.0f - top_weight;
+}
+
+static void BicubicBaseFuncFp16(float16_t a, float16_t x, float16_t *weight) {
+ float16_t abs_x = fabsf(x);
+ if (abs_x >= 0 && abs_x <= 1) {
+ *weight = ((a + 2) * abs_x - (a + 3)) * abs_x * abs_x + 1;
+ } else if (abs_x > 1 && abs_x <= 2) {
+ *weight = a * abs_x * abs_x * abs_x - 5 * a * abs_x * abs_x + 8 * a * abs_x - 4 * a;
+ } else {
+ *weight = 0;
+ }
+}
+
+// a is a coefficient
+// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1
+//        { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a * |x| - 4 * a, for 1 < |x| < 2
+//        { 0, otherwise
+// The value of 'a' depends on whether half_pixel_centers is used (the scheme is the same as in TF).
+// In half-pixel mode, a equals -0.5; otherwise it is -0.75.
+void CalculateWeightForBicubicFp16(float16_t out, int in, int *index, float16_t *weights, float16_t a) {
+ int floor_index = (int)(floorf(out));
+ index[0] = (floor_index - 1) < 0 ? 0 : (floor_index - 1);
+ index[1] = floor_index;
+ index[2] = (floor_index + 1) < in ? (floor_index + 1) : (in - 1);
+ index[3] = (floor_index + 2) < in ? (floor_index + 2) : (in - 1);
+
+ // signed distances to the four source positions (BicubicBaseFuncFp16 takes |x|)
+ float16_t distance[4] = {-1, 0, 1, 2};
+ float16_t tmp_dis = out - (float16_t)floor_index;
+ distance[0] -= tmp_dis;
+ distance[1] -= tmp_dis;
+ distance[2] -= tmp_dis;
+ distance[3] -= tmp_dis;
+
+ for (int i = 0; i < 4; ++i) {
+ BicubicBaseFuncFp16(a, distance[i], &weights[i]);
+ }
+}
+
+int PrepareResizeBilinearFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
+ int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float16_t *y_bottom_weights,
+ float16_t *x_left_weights) {
+ if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL ||
+ x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) {
+ return NNACL_NULL_PTR;
+ }
+
+ int in_h = input_shape[1];
+ int in_w = input_shape[2];
+
+ int new_height = output_shape[1];
+ int new_width = output_shape[2];
+
+ for (int h = 0; h < new_height; h++) {
+ float16_t actual_y = calculate(h, in_h, new_height);
+ CalculateCoordinateFp16(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h);
+ }
+ for (int w = 0; w < new_width; w++) {
+ float16_t actual_x = calculate(w, in_w, new_width);
+ CalculateCoordinateFp16(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w);
+ }
+ return NNACL_OK;
+}
+
+int PrepareResizeBicubicFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
+ int *y_tops, int *x_lefts, float16_t *y_weights, float16_t *x_weights,
+ float16_t cubic_coeff) {
+ if (input_shape == NULL || output_shape == NULL || y_tops == NULL || x_lefts == NULL || y_weights == NULL ||
+ x_weights == NULL) {
+ return NNACL_NULL_PTR;
+ }
+
+ int in_h = input_shape[1];
+ int in_w = input_shape[2];
+ int new_height = output_shape[1];
+ int new_width = output_shape[2];
+
+ for (int h = 0; h < new_height; h++) {
+ float16_t actual_y = calculate(h, in_h, new_height);
+ CalculateWeightForBicubicFp16(actual_y, in_h, y_tops + 4 * h, y_weights + 4 * h, cubic_coeff);
+ }
+ for (int w = 0; w < new_width; w++) {
+ float16_t actual_x = calculate(w, in_w, new_width);
+ CalculateWeightForBicubicFp16(actual_x, in_w, x_lefts + 4 * w, x_weights + 4 * w, cubic_coeff);
+ }
+ return NNACL_OK;
+}
+
+int InterpRowFp16(const float16_t *src_line, float16_t *linear_output, int new_width,
const float16_t *x_left_weights, + const int *x_lefts, const int *x_rights, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t left_w_8 = vdupq_n_f16(x_left_weights[w]); + float16x8_t right_w_8 = vdupq_n_f16(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + float16x8_t left = vld1q_f16(src_line + x_lefts[w] * in_c + c); + float16x8_t right = vld1q_f16(src_line + x_rights[w] * in_c + c); + float16x8_t interp_value = vaddq_f16(vmulq_f16(left, left_w_8), vmulq_f16(right, right_w_8)); + vst1q_f16(linear_output + w * in_c + c, interp_value); + } +#endif + int left_w_offset = x_lefts[w] * in_c; + int right_w_offset = x_rights[w] * in_c; + for (; c < in_c; c++) { + float16_t left = src_line[left_w_offset + c]; + float16_t right = src_line[right_w_offset + c]; + linear_output[w * in_c + c] = left * x_left_weights[w] + right * (1.0f - x_left_weights[w]); + } + } + return 0; +} + +int InterpColFp16(const float16_t *bottom_line, const float16_t *top_line, float16_t *output, int new_width, + float16_t y_bottom_weight, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t bottom_w_8 = vdupq_n_f16(y_bottom_weight); + float16x8_t top_w_8 = vdupq_n_f16(1.0f - y_bottom_weight); + for (; c <= in_c - C8NUM; c += C8NUM) { + float16x8_t bottom = vld1q_f16(bottom_line + w * in_c + c); + float16x8_t top = vld1q_f16(top_line + w * in_c + c); + float16x8_t interp_value = vaddq_f16(vmulq_f16(bottom, bottom_w_8), vmulq_f16(top, top_w_8)); + vst1q_f16(output + w * in_c + c, interp_value); + } +#endif + for (; c < in_c; c++) { + float16_t bottom = bottom_line[w * in_c + c]; + float16_t top = top_line[w * in_c + c]; + output[w * in_c + c] = bottom * y_bottom_weight + top * (1.0f - y_bottom_weight); + } + } + return 0; +} + +void BilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *y_bottom, const int *y_top, const int *x_left, const int *x_right, + const float16_t *y_bottom_weight, const float16_t *x_left_weight, float16_t *line0, float16_t *line1, + const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + bool cache_line_used[2] = {false, false}; + int cache_line_num[2] = {-1, -1}; + float16_t *const cache_line_ptr[2] = {line0, line1}; + float16_t *current_line_ptr[2] = {line0, line1}; + int current_line_num[2] = {-1, -1}; + + for (int h = h_begin; h < h_end; h++) { + current_line_num[0] = y_bottom[h]; + current_line_num[1] = y_top[h]; + + for (int i = 0; i < 2; i++) { + cache_line_used[i] = false; + } + // search if we cached + for (int j = 0; j < 2; j++) { + bool find = false; + for (int k = 0; k < 2; k++) { + if (current_line_num[j] == cache_line_num[k]) { + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + find = true; + break; + } + } + + if (!find) { + const float16_t *line = input_data + current_line_num[j] * in_w * in_c; + for (int k = 0; k < 2; k++) { + if (!cache_line_used[k]) { + cache_line_num[k] = current_line_num[j]; + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + InterpRowFp16(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); + break; + } + } + } + } + // do col interp + InterpColFp16(current_line_ptr[0], current_line_ptr[1], output_data + h * h_stride, new_width, y_bottom_weight[h], + in_c); + } +} + +int 
ResizeBilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_bottoms, const int *y_tops, const int *x_lefts, + const int *x_rights, const float16_t *y_bottom_weights, const float16_t *x_left_weights, + float16_t *line0, float16_t *line1, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL || + y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_b = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int b = 0; b < in_b; b++) { + const float16_t *input = input_data + b * in_h * in_w * in_c; + float16_t *output = output_data + b * new_height * new_width * in_c; + BilinearFp16(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_bottom_weights, + x_left_weights, line0, line1, h_begin, h_end); + } + return NNACL_OK; +} + +void BicubicInterpRowFp16(const float16_t *src, float16_t *dst, const float16_t *weights, const int *lefts, int width, + int channel) { + for (int w = 0; w < width; w++) { + const float16_t *weight = weights + 4 * w; + float16_t *dst_w = dst + w * channel; + const float16_t *src0_w = src + lefts[4 * w] * channel; + const float16_t *src1_w = src + lefts[4 * w + 1] * channel; + const float16_t *src2_w = src + lefts[4 * w + 2] * channel; + const float16_t *src3_w = src + lefts[4 * w + 3] * channel; + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t weight0_vec_8 = vdupq_n_f16(weight[0]); + float16x8_t weight1_vec_8 = vdupq_n_f16(weight[1]); + float16x8_t weight2_vec_8 = vdupq_n_f16(weight[2]); + float16x8_t weight3_vec_8 = vdupq_n_f16(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src0_vec = vld1q_f16(src0_w + c); + float16x8_t src1_vec = vld1q_f16(src1_w + c); + float16x8_t src2_vec = vld1q_f16(src2_w + c); + float16x8_t src3_vec = vld1q_f16(src3_w + c); + float16x8_t dst0 = vmulq_f16(src0_vec, weight0_vec_8); + float16x8_t dst1 = vmulq_f16(src1_vec, weight1_vec_8); + float16x8_t dst2 = vmulq_f16(src2_vec, weight2_vec_8); + float16x8_t dst3 = vmulq_f16(src3_vec, weight3_vec_8); + float16x8_t interp_value = vaddq_f16(dst3, vaddq_f16(dst2, vaddq_f16(dst1, dst0))); + vst1q_f16(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weight[0] + src1_w[c] * weight[1] + src2_w[c] * weight[2] + src3_w[c] * weight[3]; + } + } +} + +void BicubicInterpColFp16(const float16_t *src, float16_t *dst, const float16_t *weights, int width, int channel) { + const float16_t *src0 = src; + const float16_t *src1 = src + width * channel; + const float16_t *src2 = src + 2 * width * channel; + const float16_t *src3 = src + 3 * width * channel; + for (int w = 0; w < width; w++) { + float16_t *dst_w = dst + w * channel; + const float16_t *src0_w = src0 + w * channel; + const float16_t *src1_w = src1 + w * channel; + const float16_t *src2_w = src2 + w * channel; + const float16_t *src3_w = src3 + w * channel; + int c = 0; +#ifdef ENABLE_NEON + float16x8_t weight0_vec_8 = vdupq_n_f16(weights[0]); + float16x8_t weight1_vec_8 = vdupq_n_f16(weights[1]); + float16x8_t weight2_vec_8 = vdupq_n_f16(weights[2]); + float16x8_t weight3_vec_8 = vdupq_n_f16(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t 
src0_vec = vld1q_f16(src0_w + c); + float16x8_t src1_vec = vld1q_f16(src1_w + c); + float16x8_t src2_vec = vld1q_f16(src2_w + c); + float16x8_t src3_vec = vld1q_f16(src3_w + c); + float16x8_t dst1 = vmulq_f16(src0_vec, weight0_vec_8); + float16x8_t dst2 = vmulq_f16(src1_vec, weight1_vec_8); + float16x8_t dst3 = vmulq_f16(src2_vec, weight2_vec_8); + float16x8_t dst4 = vmulq_f16(src3_vec, weight3_vec_8); + float16x8_t interp_value = vaddq_f16(dst4, vaddq_f16(dst3, vaddq_f16(dst1, dst2))); + vst1q_f16(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weights[0] + src1_w[c] * weights[1] + src2_w[c] * weights[2] + src3_w[c] * weights[3]; + } + } +} + +void BicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *y_tops, const int *x_lefts, const float16_t *y_weights, const float16_t *x_weights, + float16_t *line_buffer, const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + for (int h = h_begin; h < h_end; h++) { + for (int i = 0; i < 4; ++i) { + BicubicInterpRowFp16(input_data + y_tops[4 * h + i] * in_w * in_c, line_buffer + i * h_stride, x_weights, x_lefts, + new_width, in_c); + } + BicubicInterpColFp16(line_buffer, output_data + h * h_stride, y_weights + 4 * h, new_width, in_c); + } +} + +int ResizeBicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_tops, const int *x_lefts, const float16_t *y_weights, + const float16_t *x_weights, float16_t *line_buffer, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_tops == NULL || + x_lefts == NULL || y_weights == NULL || x_weights == NULL) { + return NNACL_NULL_PTR; + } + int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3]; + int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3]; + for (int b = 0; b < input_shape[0]; b++) { + const float16_t *input = input_data + b * input_cube_per_batch; + float16_t *output = output_data + b * output_cube_per_batch; + BicubicFp16(input, output, input_shape, output_shape, y_tops, x_lefts, y_weights, x_weights, line_buffer, h_begin, + h_end); + } + return NNACL_OK; +} + +int ResizeNearestNeighborFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num) { + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int c = input_shape[3]; + bool align_corners = coordinate_transform_mode == 1; + for (int batch = 0; batch < output_shape[0]; batch++) { + for (int y = tid; y < output_shape[1]; y += thread_num) { + float16_t actual_y = calculate(y, input_shape[1], output_shape[1]); + int input_y; + if (align_corners) { + input_y = (int)(roundf(actual_y)); + } else { + input_y = (int)(floorf(actual_y)); + } + for (int x = 0; x < output_shape[2]; x++) { + float16_t actual_x = calculate(x, input_shape[2], output_shape[2]); + int input_x; + if (align_corners) { + input_x = (int)(roundf(actual_x)); + } else { + input_x = (int)(floorf(actual_x)); + } + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float16_t)); 
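+ // Nearest-neighbour resize copies the selected source pixel's entire channel
+ // vector per output position; align_corners rounds the source index, while the
+ // default mode floors it.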
+ }
+ }
+ }
+ return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..a0fecbe15e576b3210cc12f4eeeb6509a70b5575
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_RESIZE_FP16_H_
+#define NNACL_FP16_RESIZE_FP16_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/resize_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/crop_parameter.h"
+#include "nnacl_c/fp32/resize_fp32.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int PrepareResizeBilinearFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
+ int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float16_t *y_bottom_weights,
+ float16_t *x_left_weights);
+
+int PrepareResizeBicubicFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
+ int *y_tops, int *x_lefts, float16_t *y_weights, float16_t *x_weights,
+ float16_t cubic_coeff);
+
+int ResizeBilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape,
+ const int *output_shape, const int *y_bottoms, const int *y_tops, const int *x_lefts,
+ const int *x_rights, const float16_t *y_bottom_weights, const float16_t *x_left_weights,
+ float16_t *line0, float16_t *line1, const int h_begin, const int h_end);
+
+int ResizeBicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape,
+ const int *output_shape, const int *y_tops, const int *x_lefts, const float16_t *y_weights,
+ const float16_t *x_weights, float16_t *line_buffer, const int h_begin, const int h_end);
+
+int ResizeNearestNeighborFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape,
+ const int *output_shape, CalculateOriginalCoordinate calculate,
+ int coordinate_transform_mode, int tid, int thread_num);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_RESIZE_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..579aa6d5d95ec06012eaec88d410dfec17516a60
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c
@@ -0,0 +1,226 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/scale_fp16.h" + +void Fp16ScaleInner(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t result = vfmaq_f16(offset_8, data, scale_8); + + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } +} + +void Fp16ScaleAxis(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t result = vfmaq_f16(offset_8, data, scale_8); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index]; + } + } +} + +void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void Fp16ScaleInnerRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + 
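// vfmaq_f16 below fuses the multiply-add (offset + in_data * scale) into one
+ // instruction; vmaxq_f16 against the zero vector then applies the ReLU clamp,
+ // eight fp16 lanes at a time.
+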
float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vmaxq_f16(tmp, zeros); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void Fp16ScaleAxisRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vmaxq_f16(tmp, zeros); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } +} + +void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void Fp16ScaleInnerRelu6(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vminq_f16(vmaxq_f16(tmp, zeros), bounds); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } + } +} + +void Fp16ScaleAxisRelu6(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int 
outer_end, int axis_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vminq_f16(vmaxq_f16(tmp, zeros), bounds); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..516bfcb978fd88b5b26922d7f742d86a4d4c609d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_SCALE_FP16_H_ +#define NNACL_FP16_SCALE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SCALE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..58b4ac9185c0a26f98bdbadbb01b62976cde81a8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c @@ -0,0 +1,134 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp16/softmax_fp16.h"
+#include <math.h>
+#include "nnacl_c/fp16/exp_fp16.h"
+
+void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
+  int cur_batch_offset = 0;
+  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
+    int j = 0;
+#ifdef ENABLE_NEON
+    float16x8_t max_8 = vdupq_n_f16(-FLT16_MAX);
+    int count = (channel / C8NUM) * C8NUM;
+    for (; j < count; j += C8NUM) {
+      float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + j);
+      max_8 = vmaxq_f16(max_8, input_8);
+    }
+    float16_t max = MS_MAXVQ_F16(max_8);
+#else
+    float16_t max = -FLT_MAX;
+#endif
+    for (; j < channel; j++) {
+      float16_t input = src[cur_batch_offset + j];
+      if (input > max) {
+        max = input;
+      }
+    }
+    int k = 0;
+#ifdef ENABLE_NEON
+    int count2 = (channel / C8NUM) * C8NUM;
+    for (; k < count2; k += C8NUM) {
+      float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k);
+      float16x8_t output_8 = vsubq_f16(input_8, vdupq_n_f16(max));
+      vst1q_f16(dst + cur_batch_offset + k, output_8);
+    }
+#endif
+    for (; k < channel; k++) {
+      int offset = cur_batch_offset + k;
+      dst[offset] = src[offset] - max;
+    }
+  }
+}
+
+void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
+  int cur_batch_offset = 0;
+  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
+    float16_t sum = 0.0f;
+    int j = 0;
+#ifdef ENABLE_NEON
+    float16x8_t sum8 = vdupq_n_f16(0);
+    int count = (channel / C8NUM) * C8NUM;
+    for (; j < count; j += C8NUM) {
+      sum8 = vaddq_f16(sum8, vld1q_f16(src + cur_batch_offset + j));
+    }
+    sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7];
+#endif
+    for (; j < channel; j++) {
+      sum += src[cur_batch_offset + j];
+    }
+    int k = 0;
+#ifdef ENABLE_NEON
+    const float16_t div = 1.0f / sum;
+    for (; k < count; k += C8NUM) {
+      vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div));
+    }
+#endif
+    for (; k < channel; k++) {
+      dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum;
+    }
+  }
+}
+
+void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
+  SoftmaxNormFp16(src, dst, batch, channel);
+  ExpFp16(dst, dst, batch * channel);
+  SumAndDivFp16(dst, dst, batch, channel);
+}
+
+// output = exp(input) / reduce_sum(exp(input), axis)
+void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int axis, int n_dim,
+                 const int *input_shape) {
+  int inner_size = 1;
+  int outter_size = 1;
+
+  for (int i = 0; i < axis; i++) {
+    outter_size *= input_shape[i];
+  }
+  for (int i = axis + 1; i < n_dim; i++) {
+    inner_size *= input_shape[i];
+  }
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * input_shape[axis] * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      float16_t max_data = input_ptr[inner_offset];
+      sum_data[k + sum_outter_offset] = 0;
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset];
+      }
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data);
+        sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
+      }
+    }
+  }
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * input_shape[axis] * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int j = 0; j < input_shape[axis]; j++) {
+      int axis_offset = outter_offset + j * inner_size;
+      for (int k = 0; k < inner_size; k++) {
+        int inner_offset = axis_offset + k;
+        output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset];
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf5cf43e9ff38ea8d645b4f4cfa4ed0f033f1e44
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_SOFTMAX_FP16_H_
+#define NNACL_FP16_SOFTMAX_FP16_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h"
+#include "nnacl_c/softmax_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel);
+void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int axis, int n_dim,
+                 const int *input_shape);
+void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_SOFTMAX_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..3a2c66397a65bb25d3a8b4259e02b4955cfdf13d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/sparse_to_dense_fp16.h"
+#include "nnacl_c/errorcode.h"
+
+int SparseToDenseSetDefaultFp16(float16_t *output, float16_t default_value, SparseToDenseParameter *param,
+                                int task_id) {
+  if (output == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  if (param->op_parameter_.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  int unit_per_thread = UP_DIV(param->output_num, param->op_parameter_.thread_num_);
+  int begin = unit_per_thread * task_id;
+  int end = MSMIN(begin + unit_per_thread, param->output_num);
+  for (int i = begin; i < end; i++) {
+    output[i] = default_value;
+  }
+  return NNACL_OK;
+}
+
+int SparseToDenseFp16(int *indices_vec, const float16_t *sparse_values, float16_t default_value, float16_t *output,
+                      SparseToDenseParameter *param, int task_id) {
+  if (indices_vec == NULL || sparse_values == NULL || output == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  if (param->op_parameter_.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  int unit_per_thread = UP_DIV(param->index_num, param->op_parameter_.thread_num_);
+  int begin = unit_per_thread * task_id;
+  int end = MSMIN(begin + unit_per_thread, param->index_num);
+
+  int stride0 = param->output_stride[0];
+  int stride1 = param->output_stride[1];
+  int stride2 = param->output_stride[2];
+
+  if (param->validate_indices_ == true) {
+    int index_before = -1;
+    for (int i = begin; i < end; i++) {
+      int *indices = indices_vec + i * DIMENSION_4D;
+      int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3];
+      if (index <= index_before) {
+        return NNACL_ERR;
+      }
+      index_before = index;
+    }
+  }
+
+  if (param->is_scalar == true) {
+    for (int i = begin; i < end; i++) {
+      int *indices = indices_vec + i * DIMENSION_4D;
+      int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3];
+      output[index] = sparse_values[0];
+    }
+  } else {
+    for (int i = begin; i < end; i++) {
+      int *indices = indices_vec + i * DIMENSION_4D;
+      int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3];
+      output[index] = sparse_values[i];
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..7f225dbaecc096cac90607d0baf513f3b27138d1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ +#define NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ + +#include "nnacl_c/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SparseToDenseSetDefaultFp16(float16_t *output, float16_t default_value, SparseToDenseParameter *param, int task_id); +int SparseToDenseFp16(int *indices_vec, const float16_t *sparse_values, float16_t default_value, float16_t *output, + SparseToDenseParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d4718e37589c280dc1d0b6cd5017a9c43bb3a38f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/splice_fp16.h" +void SpliceFp16(const float16_t *src_data, int src_row, int src_col, const SpliceParameter *param, float16_t *dst_data, + int dst_row, int dst_col) { + int forward_index = 0; + for (int r = 0; r < dst_row; ++r) { + float16_t *dst_row_data = dst_data + r * dst_col; + for (int off = 0; off < param->context_dim_; ++off) { + int r_off = param->forward_indexes_[forward_index]; + forward_index++; + const float16_t *tmp_src_data = src_data + r_off * src_col; + float16_t *tmp_dst_data = dst_row_data + off * src_col; + memcpy(tmp_dst_data, tmp_src_data, (size_t)(src_col) * sizeof(float16_t)); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..2c2272fb2bc3096c340947d654685c2dbfba3732 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP16_SPLICE_FP16_H_
+#define NNACL_FP16_SPLICE_FP16_H_
+#include <string.h>
+#include "nnacl_c/splice_parameter.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void SpliceFp16(const float16_t *src_data, int src_row, int src_col, const SpliceParameter *param, float16_t *dst_data,
+                int dst_row, int dst_col);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_FP16_SPLICE_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..feeeddc9b813d80fc1cae4812855cc676a4b2140
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp16/topk_fp16.h"
+
+int TopkFp16DescendCmp(const void *a, const void *b) {
+  float16_t sub = ((const TopkFp16Node *)b)->element - ((const TopkFp16Node *)a)->element;
+  if (sub > 0) {
+    return 1;
+  } else if (sub < 0) {
+    return -1;
+  }
+  if (((const TopkFp16Node *)a)->index > ((const TopkFp16Node *)b)->index) {
+    return 1;
+  } else {
+    return -1;
+  }
+}
+
+int TopkFp16IndexSortCmp(const void *a, const void *b) {
+  if (((const TopkFp16Node *)a)->index > ((const TopkFp16Node *)b)->index) {
+    return 1;
+  } else {
+    return -1;
+  }
+}
+
+void TopkFp16(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) {
+  int dim_size = parameter->dim_size_;
+  int outer_loop_num = parameter->outer_loop_num_;
+  int inner_loop_num = parameter->inner_loop_num_;
+  int k = parameter->k_;
+  TopkFp16Node *top_map = (TopkFp16Node *)parameter->topk_node_list_;
+
+  float16_t *cur_input_data = (float16_t *)input_data;
+  float16_t *cur_output_data = (float16_t *)output_data;
+  int32_t *cur_output_index = output_index;
+  for (int i = 0; i < outer_loop_num; i++) {
+    int in_offset = i * dim_size * inner_loop_num;
+    int out_offset = i * k * inner_loop_num;
+    for (int j = 0; j < inner_loop_num; j++) {
+      for (int m = 0; m < dim_size; m++) {
+        int offset = in_offset + m * inner_loop_num + j;
+        top_map[m].element = *(cur_input_data + offset);
+        top_map[m].index = m;
+      }
+      qsort(top_map, dim_size, sizeof(top_map[0]), TopkFp16DescendCmp);
+      if (!parameter->sorted_) {
+        qsort(top_map, k, sizeof(top_map[0]), TopkFp16IndexSortCmp);
+      }
+      for (int m = 0; m < k; m++) {
+        int offset = out_offset + m * inner_loop_num + j;
+        cur_output_data[offset] = top_map[m].element;
+        cur_output_index[offset] = top_map[m].index;
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..4054851d7782edfbe3f9f072bd208d9135a67c51
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_TOPK_FP16_H_ +#define NNACL_FP16_TOPK_FP16_H_ + +#include "nnacl_c/fp32/topk_fp32.h" +#include "nnacl_c/op_base.h" + +typedef struct TopkFp16Node { + float16_t element; + int32_t index; +} TopkFp16Node; + +#ifdef __cplusplus +extern "C" { +#endif +void TopkFp16(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_TOPK_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d211eda100f7bb17758dc4980e6845737926efaa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp16/transpose_fp16.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+void Fp16TransposeDim2(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides,
+                       const int *perm, const int *output_shape) {
+  const int stride0 = strides[perm[0]];
+  const int stride1 = strides[perm[1]];
+  const int output0 = output_shape[0];
+  const int output1 = output_shape[1];
+  for (int i = 0; i < output0; ++i) {
+    int out_stride0_i = i * output1;
+    int stride0_i = i * 1 * stride0;
+    for (int j = 0; j < output1; ++j) {
+      out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1];
+    }
+  }
+}
+
+void Fp16TransposeDim3(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides,
+                       const int *perm, const int *output_shape) {
+  const int stride0 = strides[perm[0]];
+  const int stride1 = strides[perm[1]];
+  const int stride2 = strides[perm[2]];
+  const int out_stride0 = out_strides[0];
+  const int out_stride1 = out_strides[1];
+  const int output0 = output_shape[0];
+  const int output1 = output_shape[1];
+  const int output2 = output_shape[2];
+  for (int i = 0; i < output0; ++i) {
+    int out_stride0_i = i * out_stride0;
+    int stride0_i = i * stride0;
+    for (int j = 0; j < output1; ++j) {
+      int out_stride1_j = j * out_stride1;
+      int stride1_j = j * stride1;
+      for (int k = 0; k < output2; ++k) {
+        out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2];
+      }
+    }
+  }
+}
+
+void Fp16TransposeDim4(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides,
+                       const int *perm, const int *output_shape) {
+  const int stride0 = strides[perm[0]];
+  const int stride1 = strides[perm[1]];
+  const int stride2 = strides[perm[2]];
+  const int stride3 = strides[perm[3]];
+  const int out_stride0 = out_strides[0];
+  const int out_stride1 = out_strides[1];
+  const int out_stride2 = out_strides[2];
+  const int output0 = output_shape[0];
+  const int output1 = output_shape[1];
+  const int output2 = output_shape[2];
+  const int output3 = output_shape[3];
+
+  for (int i = 0; i < output0; ++i) {
+    int out_stride0_i = i * out_stride0;
+    int stride0_i = i * stride0;
+    for (int j = 0; j < output1; ++j) {
+      int out_stride1_j = j * out_stride1;
+      int stride1_j = j * stride1;
+      for (int k = 0; k < output2; ++k) {
+        int out_stride2_k = k * out_stride2;
+        int stride2_k = k * stride2;
+        for (int m = 0; m < output3; ++m) {
+          out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] =
+            in_data[stride0_i + stride1_j + stride2_k + m * stride3];
+        }
+      }
+    }
+  }
+}
+
+void Fp16TransposeDim5(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides,
+                       const int *perm, const int *output_shape) {
+  const int stride0 = strides[perm[0]];
+  const int stride1 = strides[perm[1]];
+  const int stride2 = strides[perm[2]];
+  const int stride3 = strides[perm[3]];
+  const int stride4 = strides[perm[4]];
+  const int out_stride0 = out_strides[0];
+  const int out_stride1 = out_strides[1];
+  const int out_stride2 = out_strides[2];
+  const int out_stride3 = out_strides[3];
+  const int output0 = output_shape[0];
+  const int output1 = output_shape[1];
+  const int output2 = output_shape[2];
+  const int output3 = output_shape[3];
+  const int output4 = output_shape[4];
+
+  for (int i = 0; i < output0; ++i) {
+    int out_stride0_i = i * out_stride0;
+    int stride0_i = i * stride0;
+    for (int j = 0; j < output1; ++j) {
+      int out_stride1_j = j * out_stride1;
+      int stride1_j = j * stride1;
+      for (int k =
0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void Fp16TransposeDim6(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + for (int g = 0; g < output5; ++g) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; + } + } + } + } + } + } +} + +void TransposeDimsFp16(const void *in, void *out, const int *output_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread_num) { + const float16_t *in_data = (const float16_t *)in; + float16_t *out_data = (float16_t *)out; + + NNACL_CHECK_NULL_RETURN_VOID(in_data); + NNACL_CHECK_NULL_RETURN_VOID(out_data); + NNACL_CHECK_NULL_RETURN_VOID(output_shape); + NNACL_CHECK_NULL_RETURN_VOID(perm); + NNACL_CHECK_NULL_RETURN_VOID(strides); + NNACL_CHECK_NULL_RETURN_VOID(out_strides); + NNACL_CHECK_ZERO_RETURN(thread_num); + + size_t data_size = (*out_strides) * output_shape[0]; + size_t offset_size = UP_DIV(data_size, thread_num); + size_t task_offset = offset_size * task_id; + int count = data_size - task_offset; + if (count <= 0) { + return; + } + count = MSMIN(offset_size, count); + for (size_t idx = task_offset; idx < task_offset + count; ++idx) { + int pos = idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); + int position = pos / *(out_strides + i); + int out_stride = i < num_axes - 1 ? 
out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeFp16(const void *in, void *out, const int *output_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes) { + const float16_t *in_data = (const float16_t *)in; + float16_t *out_data = (float16_t *)out; + + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + NNACL_CHECK_NULL_RETURN_ERR(perm); + NNACL_CHECK_NULL_RETURN_ERR(strides); + NNACL_CHECK_NULL_RETURN_ERR(out_strides); + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; ++i) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + for (int i = 0; i < num_axes; ++i) { + if (perm[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (num_axes == 2) { + Fp16TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + Fp16TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + Fp16TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + Fp16TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 6) { + Fp16TransposeDim6(in_data, out_data, strides, out_strides, perm, output_shape); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..36ba018b117bb16f8629308a667e65050180b868 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_TRANSPOSE_FP16_H_ +#define NNACL_FP16_TRANSPOSE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TransposeDimsFp16(const void *src, void *dst, const int *output_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread_num); +int DoTransposeFp16(const void *src, void *dst, const int *output_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_TRANSPOSE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..cb876aa955b6136d5d0c1475c3602a727fbb1d03 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/unique_fp16.h" + +int FindFp16(const float16_t *array, int len, float16_t target) { + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void UniqueFp16(const float16_t *input, int input_len, float16_t *output0, int *output0_len, int *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = FindFp16(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..c5d7defaa8279e54cd69e4686b734f9cefd38eab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_UNIQUE_FP16_H +#define NNACL_FP16_UNIQUE_FP16_H + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void UniqueFp16(const float16_t *input, int input_len, float16_t *output0, int *output0_len, int *output1); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_UNIQUE_FP16_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..0ff524c6b23a24351f7d6d3673c54b293088678b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/utils_fp16.h" +#include "nnacl_c/fp16/common_func_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +void *GetOrAllocFp16Data(TensorC *t, ExecEnv *env, bool cast) { + NNACL_CHECK_NULL_RETURN_NULL(t); + if (t->data_type_ == kNumberTypeFloat16) { + return t->data_; + } + if (t->data_type_ == kNumberTypeFloat32) { + int ele_num = NNACLGetElementNum(t); + void *fp16_data = env->Alloc(env->allocator_, ele_num * sizeof(float16_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fp16_data); + if (cast) { + Float32ToFloat16((float *)t->data_, (float16_t *)fp16_data, ele_num); + } + return fp16_data; + } + return NULL; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..223f9232468ac466f44c3ea5a97c6a7e8c640703 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_UTILS_FP16_H_ +#define NNACL_FP16_UTILS_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +void *GetOrAllocFp16Data(TensorC *t, ExecEnv *env, bool cast); + +#endif // NNACL_FP16_UTILS_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..eb458e26d2e5d33753710e07cc58273359f9ee72 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/where_fp16.h" +#include "nnacl_c/common_func.h" + +void WhereWithTripleInputsFp16(const float16_t *x, const float16_t *y, float16_t *output, const WhereArgs *param, + int task_id, int thread_num) { + const bool *condition = param->condition_; + int stride = UP_DIV(param->max_num_, thread_num); + int begin = task_id * stride; + int end = MSMIN(begin + stride, param->max_num_); + + for (int i = begin; i < end; ++i) { + bool cond = condition[param->condition_num_ > 1 ? i : 0]; + if (cond) { + output[i] = x[param->x_num_ > 1 ? i : 0]; + } else { + output[i] = y[param->y_num_ > 1 ? i : 0]; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..6d927e0ea3cf7469679939d55b760bda0e9aa605 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_WHERE_FP16_H_ +#define NNACL_FP16_WHERE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/where_parameter.h" +#include "nnacl_c/kernel/where.h" + +#ifdef __cplusplus +extern "C" { +#endif +void WhereWithTripleInputsFp16(const float16_t *x, const float16_t *y, float16_t *output, const WhereArgs *param, + int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WHERE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..452d8ab411a74dcd74d77a15aac8755b07c2a578 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c @@ -0,0 +1,360 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/winograd_transform_fp16.h" + +void PrepareTransInputFp16(const float16_t *src_data, float16_t *dst_data, int interval_x_s, int interval_x_e, + int interval_y_s, int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; + + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); + } + + // get real input block with padding + if (real_c == C8NUM) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f16(dst_addr, vld1q_f16(src_addr)); +#else + for (int k = 0; k < C8NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } + } + } else if (real_c < 8 && real_c >= 4) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + int rc = real_c - 4; +#ifdef ENABLE_NEON + vst1_f16(dst_addr, vld1_f16(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + src_addr += 4; + dst_addr += 4; + for (int i = 0; i < rc; ++i) { + dst_addr[i] = src_addr[i]; + } + } + } + } else { 
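+    // remaining-channel tail (real_c < C4NUM): no NEON width fits, so channel values are copied one scalar at a time.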
+ for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } + } + } +} + +// fp16 common winograd +void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFp16Func func) { +#ifdef ENABLE_ARM64 + const int tile_num = 16; +#else + const int tile_num = 12; +#endif + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * in_channel; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * C8NUM; + size_t dst_step = in_channel * tile_num; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only support arm64 +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func) { + const int tile_num = 16; + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 
0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * tile_num * input_unit * input_unit * C8NUM; + size_t dst_step = input_unit * tile_num * C8NUM; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, tile_num * C8NUM); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_channel = conv_param->output_channel_; + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); + + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } + out_tile_index++; + } +} + +void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int plane = output_w * output_h; + int output_channel = conv_param->output_channel_; + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? 
output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = dst_x_s + dst_y_s * output_w; + + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = (dst_tile_offset + plane * j) * C8NUM; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } + out_tile_index++; + } +} + +int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g, + const float *matrix_gt, int oc_block, int input_unit, int kernel_unit, + int filter_channel, int filter_batch, bool pack) { + // original weight format : ohwi + int oc_block_num = UP_DIV(filter_batch, oc_block); + int block_stride = filter_channel * oc_block; + int block_num_stride = block_stride * oc_block_num; + + float16_t *matrix_gt_data_fp16 = (float16_t *)(malloc(input_unit * kernel_unit * sizeof(float16_t))); + if (matrix_gt_data_fp16 == NULL) { + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit * kernel_unit); + + // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T + // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T + float16_t *tmp_data = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t))); + if (tmp_data == NULL) { + free(matrix_gt_data_fp16); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + float16_t *trans_out_data = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t))); + if (trans_out_data == NULL) { + free(tmp_data); + free(matrix_gt_data_fp16); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + +#ifndef ENABLE_ARM64 + float16_t *tmp_data1 = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t))); + if (tmp_data1 == NULL) { + free(tmp_data); + free(matrix_gt_data_fp16); + free(trans_out_data); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + float16_t *trans_out_data1 = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t))); + if (trans_out_data1 == NULL) { + free(tmp_data); + free(tmp_data1); + free(matrix_gt_data_fp16); + free(trans_out_data); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } +#endif + + int input_oz_offset = kernel_unit * kernel_unit * filter_channel; + for (int i = 0; i < filter_batch; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM64 + // tmp_data = g * GT + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit, + kernel_unit, input_unit, filter_channel); + // tmp_data1 = (tmp_data)T + PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit, input_unit, filter_channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit, kernel_unit, input_unit, 
+ filter_channel); + // trans_out_data = (trans_out_data1)T + PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit, input_unit, filter_channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit, + kernel_unit, input_unit, filter_channel); + // trans = (tmp * GT)T + MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, + filter_channel); +#endif + + if (pack) { + int in_offset = 0; + for (int j = 0; j < input_unit; ++j) { + for (int k = 0; k < input_unit; ++k) { + for (int c = 0; c < filter_channel; ++c) { + *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += filter_channel; + output_oz_offset += block_num_stride; + } + } + } else { + memcpy(winograd_data + i * filter_channel * input_unit * input_unit, trans_out_data, + filter_channel * input_unit * input_unit * sizeof(float16_t)); + } + } + +#ifndef ENABLE_ARM64 + free(tmp_data1); + free(trans_out_data1); +#endif + free(tmp_data); + free(trans_out_data); + free(matrix_gt_data_fp16); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..3ace951b01c9a6dc4f6a835abf5e14c6f2357b51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ +#define NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ + +#include <arm_neon.h> +#include <string.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/conv_fp16.h" +#include "nnacl_c/fp16/matrix_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +// fp16 common winograd +void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFp16Func func); + +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func); + +void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func); + +void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func); + +// fp16 winograd weight trans +int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g, + const float *matrix_gt, int oc_block, int input_unit, int kernel_unit, + int filter_channel, int filter_batch, bool pack); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..0f58223cf7c3c93dcd7e7cd9bbda4e3f6bc489a1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c @@ -0,0 +1,3278 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp16/winograd_utils_fp16.h" +#include "nnacl_c/fp16/matrix_fp16.h" + +#define MIN_UNIT_FP16 2 +#define MAX_UNIT_FP16 4 + +#ifdef ENABLE_ARM64 +void transpose8(float16x8_t *s0, float16x8_t *s1, float16x8_t *s2, float16x8_t *s3, float16x8_t *s4, float16x8_t *s5, + float16x8_t *s6, float16x8_t *s7) { + float32x4_t m0 = (float32x4_t)(vtrn1q_f16(*s0, *s1)); + float32x4_t m1 = (float32x4_t)(vtrn2q_f16(*s0, *s1)); + float32x4_t m2 = (float32x4_t)(vtrn1q_f16(*s2, *s3)); + float32x4_t m3 = (float32x4_t)(vtrn2q_f16(*s2, *s3)); + float32x4_t m4 = (float32x4_t)(vtrn1q_f16(*s4, *s5)); + float32x4_t m5 = (float32x4_t)(vtrn2q_f16(*s4, *s5)); + float32x4_t m6 = (float32x4_t)(vtrn1q_f16(*s6, *s7)); + float32x4_t m7 = (float32x4_t)(vtrn2q_f16(*s6, *s7)); + + float64x2_t t0 = (float64x2_t)(vtrn1q_f32(m0, m2)); + float64x2_t t2 = (float64x2_t)(vtrn2q_f32(m0, m2)); + float64x2_t t1 = (float64x2_t)(vtrn1q_f32(m1, m3)); + float64x2_t t3 = (float64x2_t)(vtrn2q_f32(m1, m3)); + float64x2_t t4 = (float64x2_t)(vtrn1q_f32(m4, m6)); + float64x2_t t6 = (float64x2_t)(vtrn2q_f32(m4, m6)); + float64x2_t t5 = (float64x2_t)(vtrn1q_f32(m5, m7)); + float64x2_t t7 = (float64x2_t)(vtrn2q_f32(m5, m7)); + + *s0 = (float16x8_t)(vtrn1q_f64(t0, t4)); + *s4 = (float16x8_t)(vtrn2q_f64(t0, t4)); + *s1 = (float16x8_t)(vtrn1q_f64(t1, t5)); + *s5 = (float16x8_t)(vtrn2q_f64(t1, t5)); + *s2 = (float16x8_t)(vtrn1q_f64(t2, t6)); + *s6 = (float16x8_t)(vtrn2q_f64(t2, t6)); + *s3 = (float16x8_t)(vtrn1q_f64(t3, t7)); + *s7 = (float16x8_t)(vtrn2q_f64(t3, t7)); +} +#endif + +static InputTransFp16Func InputTransFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList4[] = {NULL, NULL, OutputTransform4x2UnitFp16, + OutputTransform4x3UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnitFp16, + OutputTransform4x3ReluUnitFp16}; +static OutputTransFp16Func OutputTransFp16FuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6UnitFp16, + OutputTransform4x3Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList6[] = {NULL, + NULL, + OutputTransform6x2UnitFp16, + OutputTransform6x3UnitFp16, + OutputTransform6x4UnitFp16, + OutputTransform6x5UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList6[] = {NULL, + NULL, + OutputTransform6x2ReluUnitFp16, + OutputTransform6x3ReluUnitFp16, + OutputTransform6x4ReluUnitFp16, + OutputTransform6x5ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List6[] = {NULL, + NULL, + OutputTransform6x2Relu6UnitFp16, + OutputTransform6x3Relu6UnitFp16, + OutputTransform6x4Relu6UnitFp16, + OutputTransform6x5Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList8[] = {NULL, + NULL, + OutputTransform8x2UnitFp16, + OutputTransform8x3UnitFp16, + OutputTransform8x4UnitFp16, + OutputTransform8x5UnitFp16, + OutputTransform8x6UnitFp16, + OutputTransform8x7UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList8[] = {NULL, + NULL, + OutputTransform8x2ReluUnitFp16, + OutputTransform8x3ReluUnitFp16, + OutputTransform8x4ReluUnitFp16, + OutputTransform8x5ReluUnitFp16, + OutputTransform8x6ReluUnitFp16, + OutputTransform8x7ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL, + NULL, + OutputTransform8x2Relu6UnitFp16, + OutputTransform8x3Relu6UnitFp16, + OutputTransform8x4Relu6UnitFp16, + OutputTransform8x5Relu6UnitFp16, + 
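// These dispatch tables are indexed directly by output_unit, so slots 0 and 1 stay NULL: the + // smallest winograd output tile is 2x2, and GetOutputTransFp16Func below checks output_unit + // against the input_unit before indexing. +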
OutputTransform8x6Relu6UnitFp16, + OutputTransform8x7Relu6UnitFp16}; + +InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; } + +#ifdef ENABLE_ARM64 +static InputTransStepFp16Func InputTransStepFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4StepFp16, NULL, InputTransform6x6StepFp16, NULL, InputTransform8x8StepFp16}; + +static InputTransPackFp16Func InputTransPackFp16FuncList[] = {NULL, + NULL, + NULL, + NULL, + InputTransform4x4Pack16Fp16, + NULL, + InputTransform6x6Pack16Fp16, + NULL, + InputTransform8x8Pack16Fp16}; + +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit) { return InputTransStepFp16FuncList[input_unit]; } + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit) { return InputTransPackFp16FuncList[input_unit]; } +#endif + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[16]; + float16x8_t t[16]; + float16x8_t m[16]; + Load16DataFp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsubq_f16(src[offset], src[2 + offset]); + t[4 + l] = vaddq_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsubq_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsubq_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsubq_f16(t[offset], t[2 + offset]); + m[4 + l] = vaddq_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsubq_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsubq_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[16]; + float16x4_t t[16]; + float16x4_t m[16]; + Load16DataC4Fp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsub_f16(src[offset], src[2 + offset]); + t[4 + l] = vadd_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsub_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsub_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsub_f16(t[offset], t[2 + offset]); + m[4 + l] = vadd_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsub_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsub_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[16]; + float16_t t[16]; + float16_t m[16]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 4; ++l) { + const float16_t *src_ptr = src_data + l * 4 * src_step; + float16_t *dst_ptr = dst_data + l * 
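/* row pass of B(4x4) only: one transformed tile row lands every dst_row_step elements; the Pack16 routine below performs the matching column pass */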
dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t m0 = vsubq_f16(s0, s2); + float16x8_t m1 = vaddq_f16(s1, s2); + float16x8_t m2 = vsubq_f16(s2, s1); + float16x8_t m3 = vsubq_f16(s3, s1); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + + float16x8_t m0 = vsubq_f16(s00, s20); + float16x8_t m1 = vsubq_f16(s01, s21); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(s10, s20); + m1 = vaddq_f16(s11, s21); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s20, s10); + m1 = vsubq_f16(s21, s11); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s30, s10); + m1 = vsubq_f16(s31, s11); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 4; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[36]; + float16x8_t t[36]; + float16x8_t m[36]; + Load36DataFp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(src[3 + offset], src[1 + offset]); + float16x8_t tmp2 = vsubq_f16(src[4 + offset], src[2 + offset]); + t[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 4), vmulq_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vaddq_f16(vmulq_n_f16(vaddq_f16(src[1 + offset], src[2 + offset]), -4), + vaddq_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = vaddq_f16(vmulq_n_f16(vsubq_f16(src[1 + offset], src[2 + offset]), 4), + vsubq_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + t[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + t[30 + l] = + vaddq_f16(vsubq_f16(vmulq_n_f16(src[1 + offset], 4), vmulq_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(t[3 + offset], t[1 + offset]); + float16x8_t tmp2 = vsubq_f16(t[4 + offset], t[2 + offset]); + m[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 4), vmulq_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vaddq_f16(vmulq_n_f16(vaddq_f16(t[1 + offset], t[2 + offset]), -4), vaddq_f16(t[3 + offset], t[4 + 
offset])); + m[12 + l] = + vaddq_f16(vmulq_n_f16(vsubq_f16(t[1 + offset], t[2 + offset]), 4), vsubq_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + m[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + m[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[1 + offset], 4), vmulq_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[36]; + float16x4_t t[36]; + float16x4_t m[36]; + Load36DataC4Fp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(src[3 + offset], src[1 + offset]); + float16x4_t tmp2 = vsub_f16(src[4 + offset], src[2 + offset]); + t[l] = vadd_f16(vsub_f16(vmul_n_f16(src[offset], 4), vmul_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vadd_f16(vmul_n_f16(vadd_f16(src[1 + offset], src[2 + offset]), -4), + vadd_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(src[1 + offset], src[2 + offset]), 4), vsub_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + t[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + t[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(src[1 + offset], 4), vmul_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(t[3 + offset], t[1 + offset]); + float16x4_t tmp2 = vsub_f16(t[4 + offset], t[2 + offset]); + m[l] = vadd_f16(vsub_f16(vmul_n_f16(t[offset], 4), vmul_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vadd_f16(vmul_n_f16(vadd_f16(t[1 + offset], t[2 + offset]), -4), vadd_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(t[1 + offset], t[2 + offset]), 4), vsub_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + m[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + m[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(t[1 + offset], 4), vmul_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[36]; + float16_t t[36]; + float16_t m[36]; + for (int k = 0; k < 36; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = src[3 + offset] - src[1 + offset]; + float16_t tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = src[offset] * 4 - src[2 + offset] * 5 + src[4 + offset]; + t[6 + l] = (src[1 + offset] + src[2 + offset]) * -4 + (src[3 + offset] + src[4 + offset]); + t[12 + l] = (src[1 + offset] - src[2 + offset]) * 4 + (src[4 + offset] - src[3 + offset]); + t[18 + l] = tmp1 * 2 + tmp2; + t[24 + l] = tmp1 * -2 + tmp2; + t[30 + l] = src[1 + offset] * 4 - src[3 + offset] * 5 + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = t[3 + offset] - t[1 + offset]; + float16_t tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = t[offset] * 4 - t[2 + offset] * 5 + t[4 + offset]; + m[6 + l] = (t[1 + offset] + t[2 + offset]) * -4 + (t[3 + offset] + t[4 + offset]); + m[12 + l] = (t[1 + offset] - t[2 + offset]) * 4 + (t[4 + offset] - t[3 + offset]); + m[18 + l] = tmp1 * 2 + tmp2; + m[24 + l] = tmp1 * -2 + tmp2; + m[30 + l] = t[1 + offset] * 4 - t[3 + offset] * 5 + t[5 + offset]; + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + 
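// Store for the scalar tail: channel j went through the same B(6x6) coefficients + // (4, -5, +/-2) as the NEON paths above, one float16 at a time. +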
dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 6; ++l) { + const float16_t *src_ptr = src_data + l * 6 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step); + float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step); + + float16x8_t tmp1 = vsubq_f16(s3, s1); + float16x8_t tmp2 = vsubq_f16(s4, s2); + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 4), vmulq_n_f16(s2, 5)), s4); + float16x8_t m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s1, s2), -4), vaddq_f16(s3, s4)); + float16x8_t m2 = vaddq_f16(vmulq_n_f16(vsubq_f16(s1, s2), 4), vsubq_f16(s4, s3)); + float16x8_t m3 = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + float16x8_t m4 = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + float16x8_t m5 = vaddq_f16(vsubq_f16(vmulq_n_f16(s1, 4), vmulq_n_f16(s3, 5)), s5); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + vst1q_f16(dst_ptr + 4 * dst_step, m4); + vst1q_f16(dst_ptr + 5 * dst_step, m5); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 4), vmulq_n_f16(s20, 5)), s40); + float16x8_t m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 4), vmulq_n_f16(s21, 5)), s41); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vaddq_f16(s10, s20), -4), vaddq_f16(s30, s40)); + m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s11, s21), -4), vaddq_f16(s31, s41)); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s10, s20), 4), vsubq_f16(s40, s30)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s11, s21), 4), vsubq_f16(s41, s31)); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), 2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), 2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), -2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), -2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s10, 4), vmulq_n_f16(s30, 5)), s50); + m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s11, 4), vmulq_n_f16(s31, 5)), s51); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = 
src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 6; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[64]; + float16x8_t t[64]; + float16x8_t m[64]; + Load64DataFp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 0.5625), vmulq_n_f16(src[2 + offset], 3.0625)), + vmulq_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 1.125), vmulq_n_f16(src[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 2.25), vmulq_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.5625), vmulq_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.375), vmulq_n_f16(src[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.25), vmulq_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(src[1 + offset], -0.5625), vmulq_n_f16(src[3 + offset], 3.0625)), + vmulq_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 0.5625), vmulq_n_f16(t[2 + offset], 3.0625)), + vmulq_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 1.125), vmulq_n_f16(t[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 2.25), vmulq_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.5625), vmulq_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.375), vmulq_n_f16(t[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.25), vmulq_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), 
vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(t[1 + offset], -0.5625), vmulq_n_f16(t[3 + offset], 3.0625)), + vmulq_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[64]; + float16x4_t t[64]; + float16x4_t m[64]; + Load64DataC4Fp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(src[offset], 0.5625), vmul_n_f16(src[2 + offset], 3.0625)), + vmul_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 1.125), vmul_n_f16(src[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 2.25), vmul_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.5625), vmul_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.375), vmul_n_f16(src[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.25), vmul_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(src[1 + offset], -0.5625), vmul_n_f16(src[3 + offset], 3.0625)), + vmul_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(t[offset], 0.5625), vmul_n_f16(t[2 + offset], 3.0625)), + vmul_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 1.125), vmul_n_f16(t[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 2.25), vmul_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.5625), vmul_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.375), vmul_n_f16(t[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.25), vmul_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = 
vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(t[1 + offset], -0.5625), vmul_n_f16(t[3 + offset], 3.0625)), + vmul_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[64]; + float16_t t[64]; + float16_t m[64]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] * 0.5625f - src[2 + offset] * 3.0625f + src[4 + offset] * 3.5f - src[6 + offset]; + float16_t tmp1 = src[1 + offset] * 1.125f + src[5 + offset] * 0.5f; + float16_t tmp2 = src[2 + offset] * 2.25f - src[4 + offset] * 3.25f; + t[8 + l] = tmp1 + tmp2 - src[3 + offset] * 1.625f + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + src[3 + offset] * 1.625f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.5625f + src[5 + offset]; + tmp2 = src[2 + offset] * 0.5625f - src[4 + offset] * 2.5f; + t[24 + l] = tmp1 + tmp2 - src[3 + offset] * 2.5f + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + src[3 + offset] * 2.5f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.375f + src[5 + offset] * 1.5f; + tmp2 = src[2 + offset] * 0.25f - src[4 + offset] * 1.25f; + t[40 + l] = tmp1 + tmp2 - src[3 + offset] * 1.875f + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + src[3 + offset] * 1.875f + src[6 + offset]; + t[56 + l] = src[1 + offset] * -0.5625 + src[3 + offset] * 3.0625f - src[5 + offset] * 3.5f + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = t[offset] * 0.5625f - t[2 + offset] * 3.0625f + t[4 + offset] * 3.5f - t[6 + offset]; + float16_t tmp1 = t[1 + offset] * 1.125f + t[5 + offset] * 0.5f; + float16_t tmp2 = t[2 + offset] * 2.25f - t[4 + offset] * 3.25f; + m[8 + l] = tmp1 + tmp2 - t[3 + offset] * 1.625f + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + t[3 + offset] * 1.625f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.5625f + t[5 + offset]; + tmp2 = t[2 + offset] * 0.5625f - t[4 + offset] * 2.5f; + m[24 + l] = tmp1 + tmp2 - t[3 + offset] * 2.5f + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + t[3 + offset] * 2.5f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.375f + t[5 + offset] * 1.5f; + tmp2 = t[2 + offset] * 0.25f - t[4 + offset] * 1.25f; + m[40 + l] = tmp1 + tmp2 - t[3 + offset] * 1.875f + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + t[3 + offset] * 1.875f + t[6 + offset]; + m[56 + l] = t[1 + offset] * -0.5625 + t[3 + offset] * 3.0625f - t[5 + offset] * 3.5f + t[7 + offset]; + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform8x8StepFp16_uint(float16x8_t *s, float16x8_t *m) { + m[0] = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s[0], 0.5625), vmulq_n_f16(s[2], 3.0625)), vmulq_n_f16(s[4], 3.5)), s[6]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(s[1], 1.125), vmulq_n_f16(s[5], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(s[2], 2.25), vmulq_n_f16(s[4], 3.25)); + m[1] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 1.625)), s[6]); + m[2] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 1.625)), s[6]); + tmp1 = vaddq_f16(vmulq_n_f16(s[1], 0.5625), s[5]); + tmp2 = vsubq_f16(vmulq_n_f16(s[2], 0.5625), vmulq_n_f16(s[4], 2.5)); + m[3] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 2.5)), s[6]); + m[4] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 2.5)), s[6]); + tmp1 = vaddq_f16(vmulq_n_f16(s[1], 0.375), vmulq_n_f16(s[5], 1.5)); + tmp2 = 
vsubq_f16(vmulq_n_f16(s[2], 0.25), vmulq_n_f16(s[4], 1.25)); + m[5] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 1.875)), s[6]); + m[6] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 1.875)), s[6]); + m[7] = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s[1], -0.5625), vmulq_n_f16(s[3], 3.0625)), vmulq_n_f16(s[5], 3.5)), + s[7]); +} + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 8; ++l) { + const float16_t *src_ptr = src_data + l * 8 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s[8]; + float16x8_t m[8]; + + s[0] = vld1q_f16(src_ptr + 0 * src_step); + s[1] = vld1q_f16(src_ptr + 1 * src_step); + s[2] = vld1q_f16(src_ptr + 2 * src_step); + s[3] = vld1q_f16(src_ptr + 3 * src_step); + s[4] = vld1q_f16(src_ptr + 4 * src_step); + s[5] = vld1q_f16(src_ptr + 5 * src_step); + s[6] = vld1q_f16(src_ptr + 6 * src_step); + s[7] = vld1q_f16(src_ptr + 7 * src_step); + + InputTransform8x8StepFp16_uint(s, m); + + vst1q_f16(dst_ptr + 0 * dst_step, m[0]); + vst1q_f16(dst_ptr + 1 * dst_step, m[1]); + vst1q_f16(dst_ptr + 2 * dst_step, m[2]); + vst1q_f16(dst_ptr + 3 * dst_step, m[3]); + vst1q_f16(dst_ptr + 4 * dst_step, m[4]); + vst1q_f16(dst_ptr + 5 * dst_step, m[5]); + vst1q_f16(dst_ptr + 6 * dst_step, m[6]); + vst1q_f16(dst_ptr + 7 * dst_step, m[7]); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + LOAD_LINE_DATA_FP16(6); + LOAD_LINE_DATA_FP16(7); + + float16x8_t m0 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 0.5625), vmulq_n_f16(s20, 3.0625)), vmulq_n_f16(s40, 3.5)), s60); + float16x8_t m1 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 0.5625), vmulq_n_f16(s21, 3.0625)), vmulq_n_f16(s41, 3.5)), s61); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + float16x8_t tmp10 = vaddq_f16(vmulq_n_f16(s10, 1.125), vmulq_n_f16(s50, 0.5)); + float16x8_t tmp11 = vaddq_f16(vmulq_n_f16(s11, 1.125), vmulq_n_f16(s51, 0.5)); + float16x8_t tmp20 = vsubq_f16(vmulq_n_f16(s20, 2.25), vmulq_n_f16(s40, 3.25)); + float16x8_t tmp21 = vsubq_f16(vmulq_n_f16(s21, 2.25), vmulq_n_f16(s41, 3.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.5625), s50); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.5625), s51); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.5625), vmulq_n_f16(s40, 2.5)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.5625), vmulq_n_f16(s41, 2.5)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + 
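// Rows 3 and 4 of the 8x8 input transform differ only in the sign of the odd-tap terms, so the + // 0.5625/2.5 sums tmp10/tmp11 and tmp20/tmp21 are computed once and reused for both rows. +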
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.375), vmulq_n_f16(s50, 1.5)); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.375), vmulq_n_f16(s51, 1.5)); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.25), vmulq_n_f16(s40, 1.25)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.25), vmulq_n_f16(s41, 1.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s10, -0.5625), vmulq_n_f16(s30, 3.0625)), vmulq_n_f16(s50, 3.5)), s70); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s11, -0.5625), vmulq_n_f16(s31, 3.0625)), vmulq_n_f16(s51, 3.5)), s71); + vst1q_f16(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 8; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) { + if (input_unit == 4 && output_unit < 4) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList4[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List4[output_unit]; + } else { + return OutputTransFp16FuncList4[output_unit]; + } + } else if (input_unit == 6 && output_unit < 6) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList6[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List6[output_unit]; + } else { + return OutputTransFp16FuncList6[output_unit]; + } + } else if (input_unit == 8 && output_unit < 8) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList8[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List8[output_unit]; + } else { + return OutputTransFp16FuncList8[output_unit]; + } + } else { + return NULL; + } +} + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = 
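/* row pass of AT(2x4): (1, 1, 1, 0) here, (0, 1, -1, 1) for t[l + 4] */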
vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; 
k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 2] = m[l + 2] > 0 ? 
m[l + 2] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + m[l + 2] = vmin_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] < 6 ? m[l] : 6; + m[l + 2] = m[l + 2] > 0 ? m[l + 2] : 0; + m[l + 2] = m[l + 2] < 6 ? 
m[l + 2] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = 
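/* t1 + t2 is shared by AT rows (1, 1, 1, 0) and (0, 1, 1, 1) */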
vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for 
(int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + 
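// 36 transformed inputs -> 18 row-transformed values -> a 3x3 output tile; identical to + // OutputTransform6x3UnitFp16 above except for the vmaxq_f16(zero, .) clamp after the bias add. +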
+void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[18];
+  float16x8_t m[9];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), bias_ptr);
+    m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 3] = vmaxq_f16(zero, m[l + 3]);
+    m[l + 6] = vmaxq_f16(zero, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    Store9DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[18];
+  float16x8_t m[9];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), bias_ptr);
+    m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 3] = vmaxq_f16(zero, m[l + 3]);
+    m[l + 3] = vminq_f16(six, m[l + 3]);
+    m[l + 6] = vmaxq_f16(zero, m[l + 6]);
+    m[l + 6] = vminq_f16(six, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    Store9DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
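+// The 6->4 rows extend the same pattern with the cubed points:
+//   [1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1],
+// matching t[l + 18] = (src1 - src2) + 8 * (src3 - src4) + src5 below.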
+void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[24];
+  float16x8_t m[16];
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2));
+    t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4));
+    t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr);
+    m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr);
+    m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    Store16DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[24];
+  float16x8_t m[16];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2));
+    t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4));
+    t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr);
+    m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr);
+    m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 4] = vmaxq_f16(zero, m[l + 4]);
+    m[l + 8] = vmaxq_f16(zero, m[l + 8]);
+    m[l + 12] = vmaxq_f16(zero, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    Store16DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[24];
+  float16x8_t m[16];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2));
+    t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4));
+    t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr);
+    m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr);
+    m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 4] = vmaxq_f16(zero, m[l + 4]);
+    m[l + 4] = vminq_f16(six, m[l + 4]);
+    m[l + 8] = vmaxq_f16(zero, m[l + 8]);
+    m[l + 8] = vminq_f16(six, m[l + 8]);
+    m[l + 12] = vmaxq_f16(zero, m[l + 12]);
+    m[l + 12] = vminq_f16(six, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    Store16DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[30];
+  float16x8_t m[25];
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2));
+    t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4));
+    t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8));
+    t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr);
+    m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr);
+    m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr);
+    m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    Store25DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[30];
+  float16x8_t m[25];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2));
+    t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4));
+    t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8));
+    t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr);
+    m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr);
+    m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr);
+    m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 5] = vmaxq_f16(zero, m[l + 5]);
+    m[l + 10] = vmaxq_f16(zero, m[l + 10]);
+    m[l + 15] = vmaxq_f16(zero, m[l + 15]);
+    m[l + 20] = vmaxq_f16(zero, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    Store25DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[36];
+  float16x8_t t[30];
+  float16x8_t m[25];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load36DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2));
+    t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4));
+    t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8));
+    t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 6;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr);
+    m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr);
+    m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr);
+    m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 5] = vmaxq_f16(zero, m[l + 5]);
+    m[l + 5] = vminq_f16(six, m[l + 5]);
+    m[l + 10] = vmaxq_f16(zero, m[l + 10]);
+    m[l + 10] = vminq_f16(six, m[l + 10]);
+    m[l + 15] = vmaxq_f16(zero, m[l + 15]);
+    m[l + 15] = vminq_f16(six, m[l + 15]);
+    m[l + 20] = vmaxq_f16(zero, m[l + 20]);
+    m[l + 20] = vminq_f16(six, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    Store25DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[16];
+  float16x8_t m[4];
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    Store4DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
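+// The OutputTransform8x* functions use an 8x8 input tile built on the
+// interpolation points {0, 0.5, -0.5, 1, -1, 1.5, -1.5}: the constants
+// 0.25/2.25, 0.125/3.375, 0.0625/5.0625 and 0.03125/7.59375 are the
+// 2nd..5th powers of 0.5 and 1.5, so row j of A^T weights the differences
+// (src1 - src2), (src3 - src4), (src5 - src6) (odd j) or the corresponding
+// sums (even j) by 0.5^j, 1^j and 1.5^j; the last row folds in column 7.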
+void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[16];
+  float16x8_t m[4];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 2] = vmaxq_f16(zero, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    Store4DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[16];
+  float16x8_t m[4];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 2] = vmaxq_f16(zero, m[l + 2]);
+    m[l + 2] = vminq_f16(six, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    Store4DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[24];
+  float16x8_t m[9];
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    Store9DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[24];
+  float16x8_t m[9];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 3] = vmaxq_f16(zero, m[l + 3]);
+    m[l + 6] = vmaxq_f16(zero, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    Store9DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[24];
+  float16x8_t m[9];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 3] = vmaxq_f16(zero, m[l + 3]);
+    m[l + 3] = vminq_f16(six, m[l + 3]);
+    m[l + 6] = vmaxq_f16(zero, m[l + 6]);
+    m[l + 6] = vminq_f16(six, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    Store9DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[32];
+  float16x8_t m[16];
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+    t[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+    m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    Store16DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
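+// The Relu variants clamp every output vector with vmaxq_f16(zero, x) and the
+// Relu6 variants additionally apply vminq_f16(six, x), i.e. min(max(x, 0), 6),
+// so the activation is fused into the tile store instead of a separate pass.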
+void OutputTransform8x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[32];
+  float16x8_t m[16];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+    t[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+    m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 4] = vmaxq_f16(zero, m[l + 4]);
+    m[l + 8] = vmaxq_f16(zero, m[l + 8]);
+    m[l + 12] = vmaxq_f16(zero, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    Store16DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[32];
+  float16x8_t m[16];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+    t[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+    m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 4] = vmaxq_f16(zero, m[l + 4]);
+    m[l + 4] = vminq_f16(six, m[l + 4]);
+    m[l + 8] = vmaxq_f16(zero, m[l + 8]);
+    m[l + 8] = vminq_f16(six, m[l + 8]);
+    m[l + 12] = vmaxq_f16(zero, m[l + 12]);
+    m[l + 12] = vminq_f16(six, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    Store16DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[40];
+  float16x8_t m[25];
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+    t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375));
+    t[l + 32] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+    m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr);
+    m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    Store25DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[40];
+  float16x8_t m[25];
+  float16x8_t zero = vdupq_n_f16(0);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+    t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375));
+    t[l + 32] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+    m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr);
+    m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l + 5] = vmaxq_f16(zero, m[l + 5]);
+    m[l + 10] = vmaxq_f16(zero, m[l + 10]);
+    m[l + 15] = vmaxq_f16(zero, m[l + 15]);
+    m[l + 20] = vmaxq_f16(zero, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    Store25DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
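+// 8->5 adds rows built from the 4th powers (0.0625 = 0.5^4, 5.0625 = 1.5^4);
+// only the last row folds in the tile's final column t[7 + offset].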
+void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  float16x8_t src[64];
+  float16x8_t t[40];
+  float16x8_t m[25];
+  float16x8_t zero = vdupq_n_f16(0);
+  float16x8_t six = vdupq_n_f16(6);
+  Load64DataFp16;
+  float16x8_t bias_ptr = vld1q_f16(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+    t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+    t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+    t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375));
+    t[l + 32] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 8;
+    float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+    float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+    float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+    float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+    m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+    m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+    m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr);
+    m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), bias_ptr);
+    m[l] = vmaxq_f16(zero, m[l]);
+    m[l] = vminq_f16(six, m[l]);
+    m[l + 5] = vmaxq_f16(zero, m[l + 5]);
+    m[l + 5] = vminq_f16(six, m[l + 5]);
+    m[l + 10] = vmaxq_f16(zero, m[l + 10]);
+    m[l + 10] = vminq_f16(six, m[l + 10]);
+    m[l + 15] = vmaxq_f16(zero, m[l + 15]);
+    m[l + 15] = vminq_f16(six, m[l + 15]);
+    m[l + 20] = vmaxq_f16(zero, m[l + 20]);
+    m[l + 20] = vminq_f16(six, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    Store25DataFp16;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+}
+
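+// Unlike the smaller tiles, the 8->6 transforms below keep three code paths:
+// a full float16x8 path for r_c == 8 channels, a float16x4 path for
+// 4 <= r_c < 8 (which handles lanes 0..3 and sets z = 4), and a scalar tail
+// that finishes the remaining channels one at a time.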
+void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  int z = 0;
+  if (r_c == 8) {
+    float16x8_t src[64];
+    float16x8_t t[48];
+    float16x8_t m[36];
+    Load64DataFp16;
+    float16x8_t bias_ptr = vld1q_f16(bias_data);
+    for (int l = 0; l < 8; ++l) {
+      int offset = l * 8;
+      float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]);
+      float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]);
+      float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]);
+      float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]);
+      float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]);
+      float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]);
+      t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3);
+      t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5));
+      t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25));
+      t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375));
+      t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625));
+      t[l + 40] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]);
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 8;
+      float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]);
+      float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]);
+      float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]);
+      float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]);
+      float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]);
+      float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]);
+      m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+      m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr);
+      m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr);
+      m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr);
+      m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr);
+      m[l + 30] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), bias_ptr);
+    }
+    if (r_h == 6 && r_w == 6) {
+      for (int i = 0; i < 6; i++) {
+        int dst_k_offset = i * dst_step * out_c;
+        int m_k_offset = i * 6;
+        vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]);
+        vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]);
+        vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]);
+        vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]);
+        vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]);
+        vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]);
+      }
+    } else {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 6;
+        for (int k = 0; k < r_w; k++) {
+          vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]);
+        }
+      }
+    }
+    r_c -= 8;
+  } else if (r_c < 8 && r_c >= 4) {
+    float16x4_t src[64];
+    float16x4_t t[48];
+    float16x4_t m[36];
+    Load64DataC4Fp16;
+    float16x4_t bias_ptr = vld1_f16(bias_data);
+    for (int l = 0; l < 8; ++l) {
+      int offset = l * 8;
+      float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]);
+      float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]);
+      float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]);
+      float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]);
+      float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]);
+      float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]);
+      t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3);
+      t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5));
+      t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25));
+      t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375));
+      t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625));
+      t[l + 40] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]);
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 8;
+      float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]);
+      float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]);
+      float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]);
+      float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]);
+      float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]);
+      float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]);
+      m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+      m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr);
+      m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr);
+      m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr);
+      m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr);
+      m[l + 30] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), bias_ptr);
+    }
+    if (r_h == 6 && r_w == 6) {
+      for (int i = 0; i < 6; i++) {
+        int dst_k_offset = i * dst_step * out_c;
+        int m_k_offset = i * 6;
+        vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]);
+        vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]);
+        vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]);
+        vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]);
+        vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]);
+        vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]);
+      }
+    } else {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 6;
+        for (int k = 0; k < r_w; k++) {
+          vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]);
+        }
+      }
+    }
+    z = 4;
+  }
+  for (; z < r_c; ++z) {
+    float16_t src[64];
+    float16_t t[48];
+    float16_t m[36];
+    // The scalar tail needs the whole 8x8 tile, i.e. all 64 elements.
+    for (int k = 0; k < 64; ++k) {
+      src[k] = src_data[z + k * src_step];
+    }
+    float16_t bias_ptr = bias_data[z];
+    for (int l = 0; l < 8; ++l) {
+      int offset = l * 8;
+      float16_t tmp1 = src[1 + offset] + src[2 + offset];
+      float16_t tmp2 = src[3 + offset] + src[4 + offset];
+      float16_t tmp3 = src[5 + offset] + src[6 + offset];
+      float16_t tmp4 = src[1 + offset] - src[2 + offset];
+      float16_t tmp5 = src[3 + offset] - src[4 + offset];
+      float16_t tmp6 = src[5 + offset] - src[6 + offset];
+      t[l] = src[offset] + tmp1 + tmp2 + tmp3;
+      t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f;
+      t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f;
+      t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f;
+      t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f;
+      t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 8;
+      float16_t tmp1 = t[1 + offset] + t[2 + offset];
+      float16_t tmp2 = t[3 + offset] + t[4 + offset];
+      float16_t tmp3 = t[5 + offset] + t[6 + offset];
+      float16_t tmp4 = t[1 + offset] - t[2 + offset];
+      float16_t tmp5 = t[3 + offset] - t[4 + offset];
+      float16_t tmp6 = t[5 + offset] - t[6 + offset];
+      m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr;
+      m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr;
+      m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr;
+      m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr;
+      m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr;
bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } 
+    } else {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 6;
+        for (int k = 0; k < r_w; k++) {
+          vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]);
+        }
+      }
+    }
+    r_c -= 8;
+  } else if (r_c < 8 && r_c >= 4) {
+    float16x4_t src[64];
+    float16x4_t t[48];
+    float16x4_t m[36];
+    float16x4_t zero = vdup_n_f16(0);
+    Load64DataC4Fp16;
+    float16x4_t bias_ptr = vld1_f16(bias_data);
+    for (int l = 0; l < 8; ++l) {
+      int offset = l * 8;
+      float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]);
+      float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]);
+      float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]);
+      float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]);
+      float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]);
+      float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]);
+      t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3);
+      t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5));
+      t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25));
+      t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375));
+      t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625));
+      t[l + 40] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]);
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 8;
+      float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]);
+      float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]);
+      float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]);
+      float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]);
+      float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]);
+      float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]);
+      m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+      m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr);
+      m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr);
+      m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr);
+      m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr);
+      m[l + 30] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), bias_ptr);
+      m[l] = vmax_f16(zero, m[l]);
+      m[l + 6] = vmax_f16(zero, m[l + 6]);
+      m[l + 12] = vmax_f16(zero, m[l + 12]);
+      m[l + 18] = vmax_f16(zero, m[l + 18]);
+      m[l + 24] = vmax_f16(zero, m[l + 24]);
+      m[l + 30] = vmax_f16(zero, m[l + 30]);
+    }
+    if (r_h == 6 && r_w == 6) {
+      for (int i = 0; i < 6; i++) {
+        int dst_k_offset = i * dst_step * out_c;
+        int m_k_offset = i * 6;
+        vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]);
+        vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]);
+        vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]);
+        vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]);
+        vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]);
+        vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]);
+      }
+    } else {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 6;
+        for (int k = 0; k < r_w; k++) {
+          vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]);
+        }
+      }
+    }
+    z = 4;
+  }
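+  // Channels that the 8- and 4-lane paths above did not cover are finished
+  // one channel at a time in scalar float16_t code.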
r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + /* the 8x8 input tile holds 64 values per channel, all of which feed the transform */ + for (int k = 0; k < 64; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + /* two 1-D passes of the 8-to-6 output transform: row k of A^T weights node a by a^k for nodes {0, +-0.5, +-1, +-1.5}; the point at infinity feeds only the last row via src[7 + offset] */ + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 30] = m[l + 30] > 0 ? 
m[l + 30] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 18] = vminq_f16(six, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 24] = vminq_f16(six, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + m[l + 30] = vminq_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + 
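/* m[] now holds the bias-added, ReLU6-clamped 6x6 tile; element (i, k) lands at (i * dst_step + k) * out_c */ +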
vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 6] = vmin_f16(six, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 12] = vmin_f16(six, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 18] = vmin_f16(six, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 24] = vmin_f16(six, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + m[l + 30] = vmin_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, 
m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + /* load all 64 values of the 8x8 tile for this channel */ + for (int k = 0; k < 64; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + /* ReLU6: clamp every output to [0, 6] */ + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] < 6 ? m[l] : 6; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 6] = m[l + 6] < 6 ? m[l + 6] : 6; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 12] = m[l + 12] < 6 ? m[l + 12] : 6; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 18] = m[l + 18] < 6 ? m[l + 18] : 6; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 24] = m[l + 24] < 6 ? m[l + 24] : 6; + m[l + 30] = m[l + 30] > 0 ? m[l + 30] : 0; + m[l + 30] = m[l + 30] < 6 ? 
m[l + 30] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } 
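/* no full 8-channel 7x7 tile available: fall through and scatter the valid lanes element by element */ +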
else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + 
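/* bias and ReLU were already applied in the loop above; the remaining columns of each 7-wide row are stored below */ +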
vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 7] = vminq_f16(six, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 
14]); + m[l + 14] = vminq_f16(six, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 21] = vminq_f16(six, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 28] = vminq_f16(six, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 35] = vminq_f16(six, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + m[l + 42] = vminq_f16(six, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +int SelectOutputUnitFp16(const ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_c = conv_param->input_channel_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_c = conv_param->output_channel_; + int unit2 = UP_DIV(out_w * out_h, C16NUM * conv_param->op_parameter_.thread_num_); + int max_out_unit = (int)(sqrtf((float)unit2)); + max_out_unit = max_out_unit < MAX_UNIT_FP16 ? max_out_unit : MAX_UNIT_FP16; + max_out_unit = max_out_unit > MIN_UNIT_FP16 ? 
max_out_unit : MIN_UNIT_FP16; + + int unit = 0; + float max_rate = 0.0f; + float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w; + + /* heuristic search: compare direct-convolution MACs (common_cost) against an estimated Winograd transform-plus-multiply cost per tile; the penalty term biases the choice away from large tiles */ + for (int i = MIN_UNIT_FP16; i <= max_out_unit; ++i) { + int input_unit = i + kernel_w - 1; + if (!GetOutputTransFp16Func(input_unit, i, ActType_No)) { + continue; + } + float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; + float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) * + UP_DIV(out_w, i) * UP_DIV(out_h, i); + float reduce_rate = common_cost / wino_cost - penalty; + if (reduce_rate > max_rate) { + max_rate = reduce_rate; + unit = i; + } + } + /* an output unit of 1 means Winograd brings no gain, i.e. conventional convolution is used */ + if (max_rate < 1.0f) { + return 1; + } + return unit; +} + +void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, const ConvParameter *conv_param) { + *use_winograd = false; /* make sure the flag is defined on every path */ + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnitFp16(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..40bfdce60aa672db3f15397fc61c89a07f5d79e5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_WINOGRAD_UTILS_H_ +#define NNACL_FP16_WINOGRAD_UTILS_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp16/winograd_utils_fp16_macro.h" + +#define MAX_LEN 256 + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + +typedef void (*InputTransStepFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +typedef void (*InputTransPackFp16Func)(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + +typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +typedef struct TransFp16FuncList { + InputTransFp16Func in_func_; + InputTransStepFp16Func in_step_func_; + InputTransPackFp16Func in_pack_func_; + OutputTransFp16Func out_func_; +} TransFp16FuncList; + +InputTransFp16Func GetInputTransFp16Func(int input_unit); + +#ifdef ENABLE_ARM64 +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit); + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit); +#endif + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); +#endif + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int 
src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnitFp16(const 
float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +int SelectOutputUnitFp16(const ConvParameter *conv_param); + +void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h new file mode 100644 index 0000000000000000000000000000000000000000..defe39cd141739f504af8717a76dc23bd134694b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h @@ -0,0 +1,437 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ +#define NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define Load16DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); + +#define Load16DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); + +#define Load36DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 
* src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); + +#define Load36DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); + +#define Load64DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * 
src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 * src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); \ + src[36] = vld1q_f16(src_data + 36 * src_step); \ + src[37] = vld1q_f16(src_data + 37 * src_step); \ + src[38] = vld1q_f16(src_data + 38 * src_step); \ + src[39] = vld1q_f16(src_data + 39 * src_step); \ + src[40] = vld1q_f16(src_data + 40 * src_step); \ + src[41] = vld1q_f16(src_data + 41 * src_step); \ + src[42] = vld1q_f16(src_data + 42 * src_step); \ + src[43] = vld1q_f16(src_data + 43 * src_step); \ + src[44] = vld1q_f16(src_data + 44 * src_step); \ + src[45] = vld1q_f16(src_data + 45 * src_step); \ + src[46] = vld1q_f16(src_data + 46 * src_step); \ + src[47] = vld1q_f16(src_data + 47 * src_step); \ + src[48] = vld1q_f16(src_data + 48 * src_step); \ + src[49] = vld1q_f16(src_data + 49 * src_step); \ + src[50] = vld1q_f16(src_data + 50 * src_step); \ + src[51] = vld1q_f16(src_data + 51 * src_step); \ + src[52] = vld1q_f16(src_data + 52 * src_step); \ + src[53] = vld1q_f16(src_data + 53 * src_step); \ + src[54] = vld1q_f16(src_data + 54 * src_step); \ + src[55] = vld1q_f16(src_data + 55 * src_step); \ + src[56] = vld1q_f16(src_data + 56 * src_step); \ + src[57] = vld1q_f16(src_data + 57 * src_step); \ + src[58] = vld1q_f16(src_data + 58 * src_step); \ + src[59] = vld1q_f16(src_data + 59 * src_step); \ + src[60] = vld1q_f16(src_data + 60 * src_step); \ + src[61] = vld1q_f16(src_data + 61 * src_step); \ + src[62] = vld1q_f16(src_data + 62 * src_step); \ + src[63] = vld1q_f16(src_data + 63 * src_step); + +#define Load64DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); \ + src[36] = vld1_f16(src_data + 36 * src_step); \ + src[37] = vld1_f16(src_data + 37 * src_step); \ + 
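/* loads continue through src[63]: one float16x4 (4 channels) per element of the 8x8 tile */ \ +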
src[38] = vld1_f16(src_data + 38 * src_step); \ + src[39] = vld1_f16(src_data + 39 * src_step); \ + src[40] = vld1_f16(src_data + 40 * src_step); \ + src[41] = vld1_f16(src_data + 41 * src_step); \ + src[42] = vld1_f16(src_data + 42 * src_step); \ + src[43] = vld1_f16(src_data + 43 * src_step); \ + src[44] = vld1_f16(src_data + 44 * src_step); \ + src[45] = vld1_f16(src_data + 45 * src_step); \ + src[46] = vld1_f16(src_data + 46 * src_step); \ + src[47] = vld1_f16(src_data + 47 * src_step); \ + src[48] = vld1_f16(src_data + 48 * src_step); \ + src[49] = vld1_f16(src_data + 49 * src_step); \ + src[50] = vld1_f16(src_data + 50 * src_step); \ + src[51] = vld1_f16(src_data + 51 * src_step); \ + src[52] = vld1_f16(src_data + 52 * src_step); \ + src[53] = vld1_f16(src_data + 53 * src_step); \ + src[54] = vld1_f16(src_data + 54 * src_step); \ + src[55] = vld1_f16(src_data + 55 * src_step); \ + src[56] = vld1_f16(src_data + 56 * src_step); \ + src[57] = vld1_f16(src_data + 57 * src_step); \ + src[58] = vld1_f16(src_data + 58 * src_step); \ + src[59] = vld1_f16(src_data + 59 * src_step); \ + src[60] = vld1_f16(src_data + 60 * src_step); \ + src[61] = vld1_f16(src_data + 61 * src_step); \ + src[62] = vld1_f16(src_data + 62 * src_step); \ + src[63] = vld1_f16(src_data + 63 * src_step); + +#define LOAD_LINE_DATA_FP16(line) \ + float16x8_t s##line##0 = vld1q_f16(src_ptr + line * src_point_stride + 0 * pack_tile); \ + float16x8_t s##line##1 = vld1q_f16(src_ptr + line * src_point_stride + 1 * pack_tile); + +#define TRANSPOSE_16x8 \ + float16x8_t s0 = vld1q_f16(src_ptr + 0 * pack_tile); \ + float16x8_t s2 = vld1q_f16(src_ptr + 1 * pack_tile); \ + float16x8_t s4 = vld1q_f16(src_ptr + 2 * pack_tile); \ + float16x8_t s6 = vld1q_f16(src_ptr + 3 * pack_tile); \ + float16x8_t s8 = vld1q_f16(src_ptr + 4 * pack_tile); \ + float16x8_t s10 = vld1q_f16(src_ptr + 5 * pack_tile); \ + float16x8_t s12 = vld1q_f16(src_ptr + 6 * pack_tile); \ + float16x8_t s14 = vld1q_f16(src_ptr + 7 * pack_tile); \ + float16x8_t s1 = vld1q_f16(src_ptr + 8 * pack_tile); \ + float16x8_t s3 = vld1q_f16(src_ptr + 9 * pack_tile); \ + float16x8_t s5 = vld1q_f16(src_ptr + 10 * pack_tile); \ + float16x8_t s7 = vld1q_f16(src_ptr + 11 * pack_tile); \ + float16x8_t s9 = vld1q_f16(src_ptr + 12 * pack_tile); \ + float16x8_t s11 = vld1q_f16(src_ptr + 13 * pack_tile); \ + float16x8_t s13 = vld1q_f16(src_ptr + 14 * pack_tile); \ + float16x8_t s15 = vld1q_f16(src_ptr + 15 * pack_tile); \ + transpose8(&s0, &s2, &s4, &s6, &s8, &s10, &s12, &s14); \ + transpose8(&s1, &s3, &s5, &s7, &s9, &s11, &s13, &s15); \ + vst1q_f16(src_ptr + 0 * pack_tile, s0); \ + vst1q_f16(src_ptr + 1 * pack_tile, s1); \ + vst1q_f16(src_ptr + 2 * pack_tile, s2); \ + vst1q_f16(src_ptr + 3 * pack_tile, s3); \ + vst1q_f16(src_ptr + 4 * pack_tile, s4); \ + vst1q_f16(src_ptr + 5 * pack_tile, s5); \ + vst1q_f16(src_ptr + 6 * pack_tile, s6); \ + vst1q_f16(src_ptr + 7 * pack_tile, s7); \ + vst1q_f16(src_ptr + 8 * pack_tile, s8); \ + vst1q_f16(src_ptr + 9 * pack_tile, s9); \ + vst1q_f16(src_ptr + 10 * pack_tile, s10); \ + vst1q_f16(src_ptr + 11 * pack_tile, s11); \ + vst1q_f16(src_ptr + 12 * pack_tile, s12); \ + vst1q_f16(src_ptr + 13 * pack_tile, s13); \ + vst1q_f16(src_ptr + 14 * pack_tile, s14); \ + vst1q_f16(src_ptr + 15 * pack_tile, s15); + +#define Store4DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + dst_step * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store4DataC4Fp16 \ + vst1_f16(dst_data, 
m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + dst_step * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store9DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store16DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + 4 * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1q_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * 
dst_step * out_c + out_c, m[11]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +#define Store25DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + 4 * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..f9f35b6d8d51fcea6ae1a576ccb512958e8a817a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c @@ -0,0 +1,151 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16_grad/activation_grad_fp16.h" +#include +#include +#ifdef ENABLE_NEON +#include +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +int ReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t zero_v = vdupq_n_f16(0); + for (; i <= length - C8NUM; i += C8NUM) { + float16x8_t src0_v = vld1q_f16(src0 + i); + float16x8_t src1_v = vld1q_f16(src1 + i); + uint16x8_t mask_v = vcleq_f16(src1_v, zero_v); + float16x8_t dst_v = vbslq_f16(mask_v, zero_v, src0_v); + vst1q_f16(dst + i, dst_v); + } +#endif + for (; i < length; i++) { + dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int Relu6Fp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t zero_8 = vdupq_n_f16(0); + float16x8_t six_8 = vdupq_n_f16(6); + for (; i <= length - C8NUM; i += C8NUM) { + float16x8_t src1_8 = vld1q_f16(src1 + i); + float16x8_t src0_8 = vld1q_f16(src0 + i); + uint16x8_t gt_8 = vcgtq_f16(src1_8, zero_8); + uint16x8_t le_8 = vcleq_f16(src1_8, six_8); + uint16x8_t mask_8 = vandq_u16(gt_8, le_8); + float16x8_t dst_8 = vbslq_f16(mask_8, src0_8, zero_8); + vst1q_f16(dst + i, dst_8); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int LReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 one_8 = vdupq_n_f16(1); + for (; i <= length - C8NUM; i += C8NUM) { + MS_FLOAT16X8 src0_8 = MS_LDQ_F16(src0 + i); + MS_FLOAT16X8 src1_8 = MS_LDQ_F16(src1 + i); + MS_STQ_F16(dst + i, vmulq_f16(src0_8, vmulq_f16(src1_8, (one_8 - src1_8)))); + } +#endif + for (; i < length; ++i) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int SigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t one_8 = vdupq_n_f16(1); + for (; i < length - C8NUM; i += C8NUM) { + float16x8_t src0_8 = vld1q_f16(src0 + i); + float16x8_t src1_8 = vld1q_f16(src1 + i); + float16x8_t dst_8 = vmulq_f16(src0_8, vmulq_f16(src1_8, vsubq_f16(one_8, src1_8))); + vst1q_f16(dst + i, dst_8); + } +#endif + for (; i < length; i++) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int TanhFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = (float16_t)((1.0f - ((float)src1[i] * (float)src1[i])) * (float)src0[i]); + } + return NNACL_OK; +} + +int HSwishFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + float16_t tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +int HSigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + float16_t tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h
new file mode 100644
index 0000000000000000000000000000000000000000..99e7c463c5c419df93e486e8fd016d136a302f7b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_
+#define NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int Relu6Fp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int LReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha);
+int SigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int TanhFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int HSwishFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int HSigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int EluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha);
+int GeluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..40f63a9ae72a9004662c8e28d2e4ec9b1dce0a34
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16_grad/arithmetic_grad.h"
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/fp32_grad/utils.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+
+void ElementDivNegSquareFp16(const float16_t *nom, const float16_t *denom, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = -nom[i] / (denom[i] * denom[i]);
+  }
+}
+
+void ElementMulAndDivNegSquareFp16(const float16_t *a, const float16_t *b, const float16_t *denom, float16_t *output,
+                                   int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = -a[i] * b[i] / (denom[i] * denom[i]);
+  }
+}
+
+int ElementAbsGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    out[i] = (in1[i] < 0.f) ? -in2[i] : ((in1[i] > 0.f) ?
in2[i] : 0); + } + return NNACL_OK; +} + +void MaximumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] > input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] >= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float16_t)); // zero output + memset(output1, 0, num_output1 * sizeof(float16_t)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +void MinimumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] < input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] <= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float16_t)); // zero output + memset(output1, 0, num_output1 * sizeof(float16_t)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] <= input0[offset0] ? 
dy[yt_offset] : 0.;
+    } while (NextIndex(num_dims, dy_dims, input_iter));
+  }
+}
+
+int ElementSqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    out[i] = 0.5f * in2[i] / in1[i];
+  }
+  return NNACL_OK;
+}
+
+int ElementRsqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    out[i] = -0.5f * in2[i] * in1[i] * in1[i] * in1[i];
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..77aad4cb717449a123a4a077061c2c7640d8b11d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_
+#define NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void ElementDivNegSquareFp16(const float16_t *nom, const float16_t *denom, float16_t *output, int element_size);
+void ElementMulAndDivNegSquareFp16(const float16_t *a, const float16_t *b, const float16_t *denom, float16_t *output,
+                                   int element_size);
+int ElementAbsGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, int element_size);
+void MaximumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims,
+                       const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1,
+                       int num_dims);
+void MinimumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims,
+                       const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1,
+                       int num_dims);
+int ElementSqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size);
+int ElementRsqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_
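The two element-wise gradients above follow directly from the chain rule when in1 holds the forward output y and in2 the incoming gradient dL/dy; a scalar sketch under that reading (the helper names are ours, for illustration):

// y = sqrt(x)   =>  dy/dx = 1/(2*sqrt(x)) = 0.5/y,       so dL/dx = 0.5 * dL/dy / y
float SqrtGradRef(float y, float dldy) { return 0.5f * dldy / y; }

// y = 1/sqrt(x) =>  dy/dx = -0.5 * x^(-3/2) = -0.5*y^3,  so dL/dx = -0.5 * dL/dy * y^3
float RsqrtGradRef(float y, float dldy) { return -0.5f * dldy * y * y * y; }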
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..5311f1f84733b9ec6dd13075427d70a32a6a0a09
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp16_grad/arithmetic_self_grad.h"
+#include "nnacl_c/errorcode.h"
+
+int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
+  int i = 0;
+#ifdef ENABLE_NEON
+  float16x8_t log_10 = vdupq_n_f16(logf(10.0f));
+  for (; i <= (int)length - C8NUM; i += C8NUM) {
+    float16x8_t src0_8 = vld1q_f16(src0 + i);
+    float16x8_t src1_8 = vld1q_f16(src1 + i);
+    float16x8_t dst_8 = vmulq_f16(src0_8, vrecpeq_f16(vmulq_f16(src1_8, log_10)));
+    vst1q_f16(dst + i, dst_8);
+  }
+#endif
+  for (; i < (int)length; i++) {
+    dst[i] = src0[i] * 1.0f / (src1[i] * logf(10.0f));
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e4f6f4a9a4c4c89de84645f80b7af005cc143cd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_GRAD_ARITHMETHIC_SELF_GRAD_H_
+#define NNACL_FP16_GRAD_ARITHMETHIC_SELF_GRAD_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <math.h>
+#include "nnacl_c/op_base.h"
+
+typedef struct ArithmeticSelfGradParameterFp16 {
+  OpParameter op_parameter;
+  int type_;
+} ArithmeticSelfGradParameterFp16;
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_ARITHMETHIC_SELF_GRAD_H_
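Fp16LogGrad computes dst = src0 / (src1 * ln(10)); assuming src1 holds the forward input x, that is the incoming gradient scaled by d(log10(x))/dx = 1 / (x * ln(10)). A hypothetical scalar cross-check, not part of the patch (the reference name is ours):

#include <math.h>

// Matches the scalar tail loop of Fp16LogGrad, in fp32 for comparison.
float Log10GradRef(float dy, float x) { return dy / (x * logf(10.0f)); }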
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c
new file mode 100644
index 0000000000000000000000000000000000000000..33ac0542815491e17d7db58b97c340d2492b1615
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <math.h>
+#include <stddef.h>
+#include "nnacl_c/fp16_grad/batch_norm.h"
+
+void var2InvarFp16(float16_t *save_var, int size, float eps) {
+  for (int i = 0; i < size; i++) {
+    save_var[i] = (float16_t)(1.0f / sqrtf((float)save_var[i] + eps));
+  }
+}
+
+void backwardAllFp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean,
+                     const float16_t *restrict invar, const float16_t *restrict scale, int size, int ch,
+                     float *restrict dxhat_sum, float *restrict dxhathat_sum, float16_t *restrict dbias,
+                     float16_t *restrict dscale, float16_t *restrict dx) {
+  NNACL_CHECK_ZERO_RETURN(size);
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // dscale
+      float16_t x_hat = (in[ix] - mean[c]) * invar[c];
+      dscale[c] += (yt[ix] * x_hat);
+      // dx_1
+      float dx_hat = (float)(yt[ix] * scale[c]);
+      dxhat_sum[c] += dx_hat;
+      dxhathat_sum[c] += (float)(dx_hat * x_hat);
+    }
+  }
+  float N = (float)size;
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      // dx_2
+      int ix = i * ch + c;
+      float16_t x_hat = (in[ix] - mean[c]) * invar[c];
+      float16_t dx_hat = yt[ix] * scale[c];
+      dx[ix] = (float16_t)((float)((invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c])) / N);
+    }
+  }
+}
+
+void backwardP1Fp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean,
+                    const float16_t *restrict invar, const float16_t *restrict scale, int size, int ch,
+                    float *restrict dxhat_sum, float *restrict dxhathat_sum, float16_t *restrict dbias,
+                    float16_t *restrict dscale) {
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // dscale
+      float x_hat = (float)((in[ix] - mean[c]) * invar[c]);
+      dscale[c] += (yt[ix] * x_hat);
+      // dx_1
+      float dx_hat = (float)(yt[ix] * scale[c]);
+      dxhat_sum[c] += dx_hat;
+      dxhathat_sum[c] += dx_hat * x_hat;
+    }
+  }
+}
+
+void backwardP2Fp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean,
+                    const float16_t *restrict invar, const float16_t *restrict scale, int size, int total_size, int ch,
+                    const float *dxhat_sum, const float *dxhathat_sum, float16_t *restrict dx) {
+  NNACL_CHECK_ZERO_RETURN(total_size);
+  const float N = (float)total_size;
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      // dx_2
+      int ix = i * ch + c;
+      float x_hat = (float)((in[ix] - mean[c]) * invar[c]);
+      float dx_hat = (float)(yt[ix] * scale[c]);
+      dx[ix] = (float16_t)(((float)(invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c])) / N);
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..b744aa4d5bfed19b83913cdb7ccac6f8d355a42e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_BATCH_NORM_H_ +#define NNACL_FP16_GRAD_BATCH_NORM_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void var2InvarFp16(float16_t *save_var, int size, float eps); +void backwardAllFp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float16_t *dbias, + float16_t *dscale, float16_t *dx); +void backwardP1Fp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float16_t *dbias, + float16_t *dscale); +void backwardP2Fp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int total_size, int ch, const float *dxhat_sum, + const float *dxhathat_sum, float16_t *dx); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c new file mode 100644 index 0000000000000000000000000000000000000000..d3b07044e273c13e08fcb475013f2936fc58f140 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c @@ -0,0 +1,361 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16_grad/convolution_grad_filter.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/errorcode.h" +#ifdef ENABLE_NEON +#include +#endif + +#ifdef ENABLE_NEON + +static int FilterGrad32Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~31); i_c += 32) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + float32x4_t sum_2 = vdupq_n_f32(0.0f); + float32x4_t sum_3 = vdupq_n_f32(0.0f); + float32x4_t sum_4 = vdupq_n_f32(0.0f); + float32x4_t sum_5 = vdupq_n_f32(0.0f); + float32x4_t sum_6 = vdupq_n_f32(0.0f); + float32x4_t sum_7 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + + float16x8_t x_1 = vld1q_f16(x_addr + offset_x + 8); + float16x8_t dy_1 = vld1q_f16(dy_addr + offset_dy + 8); + sum_2 = MS_VMLAL_F16(vget_low_f16(x_1), vget_low_f16(dy_1), sum_2); + sum_3 = MS_VMLAL_F16(vget_high_f16(x_1), vget_high_f16(dy_1), sum_3); + + float16x8_t x_2 = vld1q_f16(x_addr + offset_x + 16); + float16x8_t dy_2 = vld1q_f16(dy_addr + offset_dy + 16); + sum_4 = MS_VMLAL_F16(vget_low_f16(x_2), vget_low_f16(dy_2), sum_4); + sum_5 = MS_VMLAL_F16(vget_high_f16(x_2), vget_high_f16(dy_2), sum_5); + + float16x8_t x_3 = vld1q_f16(x_addr + offset_x + 24); + float16x8_t dy_3 = vld1q_f16(dy_addr + offset_dy + 24); + sum_6 = MS_VMLAL_F16(vget_low_f16(x_3), vget_low_f16(dy_3), sum_6); + sum_7 = MS_VMLAL_F16(vget_high_f16(x_3), vget_high_f16(dy_3), sum_7); + } + } + } + // store into memory + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + 4 + l) * k_spatial + k_idx] = sum_1[l]; + dw[(i_c + 8 + l) * k_spatial + k_idx] = sum_2[l]; + dw[(i_c + 12 + l) * k_spatial + k_idx] = sum_3[l]; + dw[(i_c + 16 + l) * k_spatial + k_idx] = sum_4[l]; + dw[(i_c + 20 + l) * k_spatial + k_idx] = sum_5[l]; + dw[(i_c + 24 + l) * k_spatial + k_idx] = sum_6[l]; + dw[(i_c + 28 + l) * k_spatial + k_idx] = sum_7[l]; + } + } + return i_c; +} + +static int FilterGrad16Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int 
in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~15); i_c += 16) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + float32x4_t sum_2 = vdupq_n_f32(0.0f); + float32x4_t sum_3 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + + float16x8_t x_1 = vld1q_f16(x_addr + offset_x + 8); + float16x8_t dy_1 = vld1q_f16(dy_addr + offset_dy + 8); + sum_2 = MS_VMLAL_F16(vget_low_f16(x_1), vget_low_f16(dy_1), sum_2); + sum_3 = MS_VMLAL_F16(vget_high_f16(x_1), vget_high_f16(dy_1), sum_3); + } + } + } + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + l + 4) * k_spatial + k_idx] = sum_1[l]; + dw[(i_c + l + 8) * k_spatial + k_idx] = sum_2[l]; + dw[(i_c + l + 12) * k_spatial + k_idx] = sum_3[l]; + } + } + return i_c; +} + +static int FilterGrad8Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~7); i_c += 8) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = 
MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + } + } + } + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + 4 + l) * k_spatial + k_idx] = sum_1[l]; + } + } + return i_c; +} + +static int FilterGrad4Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~3); i_c += 4) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x4_t x_0 = vld1_f16(x_addr + offset_x); + float16x4_t dy_0 = vld1_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(x_0, dy_0, sum_0); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_0[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_0[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_0[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_0[3]; + } + return i_c; +} + +static int FilterGradLeftoverArm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int leftover = out_ch - i_c; + if (leftover > 0) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x4_t x_0 = vld1_f16(x_addr + offset_x); + float16x4_t dy_0 = vld1_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(x_0, dy_0, sum_0); + } + } + } + for (int l = 0; l < 
leftover; l++) {
+      dw[(i_c + l) * k_spatial + k_idx] = sum_0[l];
+    }
+  }
+  return out_ch;
+}
+
+#endif
+
+int ConvDwFilterFp16Grad(const float16_t *x, const float16_t *dy, float16_t *dw, int start, int count,
+                         const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+
+  for (int i_k = 0; i_k < count; i_k++) {
+    int k_idx = start + i_k;
+    int i_kh = k_idx / k_w;
+    int i_kw = k_idx % k_w;
+    int i_c = 0;
+#ifdef ENABLE_NEON
+    i_c = FilterGrad32Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGradLeftoverArm(x, dy, i_c, k_idx, dw, conv_param);
+#endif
+    for (; i_c < out_ch; i_c++) {
+      float sum = 0;
+      for (int b = 0; b < batch; ++b) {
+        const float16_t *x_addr = &x[b * x_size];
+        const float16_t *dy_addr = &dy[b * y_size];
+
+        for (int i = 0; i < m; i++) {
+          int idx = i;
+          int input_h = idx / out_w * conv_param->stride_h_;
+          int input_w = idx % out_w * conv_param->stride_w_;
+          int input_row = -conv_param->pad_u_ + i_kh + input_h;
+          int input_col = -conv_param->pad_l_ + i_kw + input_w;
+          if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+            int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+            int offset_dy = idx * out_ch + i_c;
+            sum += x_addr[offset_x] * dy_addr[offset_dy];
+          }
+        }
+      }
+      dw[i_c * k_spatial + k_idx] = sum;
+    }
+  }
+  return NNACL_OK;
+}
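ConvDwFilterFp16Grad accumulates dw over the NHWC batch for a (start, count) slice of the kernel's k_h * k_w spatial positions, which is the natural unit for splitting the work across threads. A minimal single-threaded wrapper, assuming only the header below (the wrapper name is ours, for illustration):

#include "nnacl_c/fp16_grad/convolution_grad_filter.h"

// Covers every kernel position in one call; a threaded caller would instead
// hand each worker its own (start, count) range within [0, k_h * k_w).
int ConvDwFilterGradAllPositions(const float16_t *x, const float16_t *dy, float16_t *dw,
                                 const ConvParameter *conv_param) {
  int k_spatial = conv_param->kernel_h_ * conv_param->kernel_w_;
  return ConvDwFilterFp16Grad(x, dy, dw, 0, k_spatial, conv_param);
}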
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce3a413e398803a89dd7ec990e3d4b63b78c4e28
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_
+#define NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_
+
+#include <stddef.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ConvDwFilterFp16Grad(const float16_t *x, const float16_t *dy, float16_t *dw, int start, int count,
+                         const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c
new file mode 100644
index 0000000000000000000000000000000000000000..4ff6af5e7ff8a34bce233c7f5fbe56ea2b7531cd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c
@@ -0,0 +1,332 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16_grad/convolution_grad_input.h"
+#include "nnacl_c/errorcode.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+static int ConvDwInputGrad16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end,
+                             const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+  int out_h = conv_param->output_h_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_spatial = conv_param->output_h_ * conv_param->output_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int k_spatial = k_h * k_w;
+  int batch = conv_param->input_batch_;
+  int in_size = in_h * in_w * in_ch;
+  int out_size = out_h * out_w * out_ch;
+
+  int j = start;
+  for (; j <= (end - C16NUM); j += C16NUM) {
+    float16_t *c = dx + j;
+    const float16_t *mat_b[C16NUM];
+    for (int j_i = 0; j_i < C16NUM; j_i++) {
+      mat_b[j_i] = w + (j + j_i) * k_spatial;
+    }
+    for (int si = 0; si < out_spatial; si++) {
+      const float16_t *a = dy + j + si * out_ch;
+      int output_row = (si) / out_w;
+      int output_col = (si) % out_w;
+      int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_;
+      int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_;
+      for (int k = 0; k < k_spatial; k++) {
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+#ifdef ENABLE_ARM
+#ifdef ENABLE_ARM64
+          float16x8_t mat_b0 = {mat_b[0][k],  mat_b[1][k],  mat_b[2][k],  mat_b[3][k],
+                                mat_b[4][k],  mat_b[5][k],  mat_b[6][k],  mat_b[7][k]};
+          float16x8_t mat_b1 = {mat_b[8][k],  mat_b[9][k],  mat_b[10][k], mat_b[11][k],
+                                mat_b[12][k], mat_b[13][k], mat_b[14][k], mat_b[15][k]};
+#else
+          float16x4_t mat_b00;
+          float16x4_t mat_b01;
+
float16x4_t mat_b10; + float16x4_t mat_b11; + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b00), "=w"(mat_b01) + : "r"(mat_b[0] + k), "r"(mat_b[1] + k), "r"(mat_b[2] + k), "r"(mat_b[3] + k), "r"(mat_b[4] + k), + "r"(mat_b[5] + k), "r"(mat_b[6] + k), "r"(mat_b[7] + k) + :); + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b10), "=w"(mat_b11) + : "r"(mat_b[8] + k), "r"(mat_b[9] + k), "r"(mat_b[10] + k), "r"(mat_b[11] + k), "r"(mat_b[12] + k), + "r"(mat_b[13] + k), "r"(mat_b[14] + k), "r"(mat_b[15] + k) + :); + float16x8_t mat_b0 = vcombine_f16(mat_b00, mat_b01); + float16x8_t mat_b1 = vcombine_f16(mat_b10, mat_b11); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x8_t mat_c0 = vld1q_f16(c + dx_offset); + float16x8_t mat_a0 = vld1q_f16(a + dy_offset); + mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0); + vst1q_f16(c + dx_offset, mat_c0); + + float16x8_t mat_c1 = vld1q_f16(c + dx_offset + 8); + float16x8_t mat_a1 = vld1q_f16(a + dy_offset + 8); + mat_c1 = vfmaq_f16(mat_c1, mat_b1, mat_a1); + vst1q_f16(c + dx_offset + 8, mat_c1); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + for (int j_i = 0; j_i < C16NUM; j_i++) { + c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k]; + } + } +#endif + } + } + } + } + return j; +} + +static int ConvDwInputGrad8(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C8NUM); j += C8NUM) { + float16_t *c = dx + j; + const float16_t *mat_b[C8NUM]; + for (int j_i = 0; j_i < C8NUM; j_i++) { + mat_b[j_i] = w + (j + j_i) * k_spatial; + } + + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k], + mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]}; +#else + float16x4_t mat_b00; + 
float16x4_t mat_b01; + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b00), "=w"(mat_b01) + : "r"(mat_b[0] + k), "r"(mat_b[1] + k), "r"(mat_b[2] + k), "r"(mat_b[3] + k), "r"(mat_b[4] + k), + "r"(mat_b[5] + k), "r"(mat_b[6] + k), "r"(mat_b[7] + k) + :); + float16x8_t mat_b0 = vcombine_f16(mat_b00, mat_b01); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x8_t mat_c0 = vld1q_f16(c + dx_offset); + float16x8_t mat_a0 = vld1q_f16(a + dy_offset); + mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0); + vst1q_f16(c + dx_offset, mat_c0); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + for (int j_i = 0; j_i < C8NUM; j_i++) { + c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k]; + } + } +#endif + } + } + } + } + return j; +} + +static int ConvDwInputGrad4(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C4NUM); j += C4NUM) { + float16_t *c = dx + j; + const float16_t *mat_b_0 = w + (j + 0) * k_spatial; + const float16_t *mat_b_1 = w + (j + 1) * k_spatial; + const float16_t *mat_b_2 = w + (j + 2) * k_spatial; + const float16_t *mat_b_3 = w + (j + 3) * k_spatial; + + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]}; +#else + float16x4_t mat_b; + asm volatile( + "vld1.16 %0[0], [%1]\n" + "vld1.16 %0[1], [%2]\n" + "vld1.16 %0[2], [%3]\n" + "vld1.16 %0[3], [%4]\n" + : "=w"(mat_b) + : "r"(mat_b_0 + k), "r"(mat_b_1 + k), "r"(mat_b_2 + k), "r"(mat_b_3 + k) + :); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x4_t mat_c = vld1_f16(c + dx_offset); + float16x4_t mat_a = vld1_f16(a + dy_offset); + mat_c = vfma_f16(mat_c, mat_b, mat_a); + vst1_f16(c + dx_offset, mat_c); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + c[dx_offset + 0] += a[dy_offset + 0] * 
mat_b_0[k]; + c[dx_offset + 1] += a[dy_offset + 1] * mat_b_1[k]; + c[dx_offset + 2] += a[dy_offset + 2] * mat_b_2[k]; + c[dx_offset + 3] += a[dy_offset + 3] * mat_b_3[k]; + } +#endif + } + } + } + } + return j; +} + +int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int end = start + count; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + j = ConvDwInputGrad16(dy, w, dx, j, end, conv_param); + j = ConvDwInputGrad8(dy, w, dx, j, end, conv_param); + j = ConvDwInputGrad4(dy, w, dx, j, end, conv_param); + for (; j < end; j++) { + float16_t *c = dx + j; + const float16_t *b = w + j * k_spatial; + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = si / out_w; + int output_col = si % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; + for (int bi = 0; bi < batch; bi++) { + c[bi * in_size + offset + 0] += a[0 + bi * out_size] * b[k]; + } + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h new file mode 100644 index 0000000000000000000000000000000000000000..5e7c2485cb5fa73a1a934817ddc95f121a14655f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP16_GRAD_CONVOLUTION_GRAD_INPUT_H_
+#define NNACL_FP16_GRAD_CONVOLUTION_GRAD_INPUT_H_
+
+#include <stddef.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count,
+                        const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_CONVOLUTION_GRAD_INPUT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..48695c53c093f92022c44e6990ef579f4450a7a5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16_grad/dropout_grad.h"
+
+void DropoutFp16Grad(const float16_t *yt_ptr, const float16_t *mask, float16_t *output_ptr, int length,
+                     float16_t scale) {
+  for (int i = 0; i < length; i++) {
+    output_ptr[i] = yt_ptr[i] * mask[i] * scale;
+  }
+}
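DropoutFp16Grad simply rescales the masked gradient; with inverted dropout the backward scale is the same 1/(1 - ratio) factor used in the forward pass, an assumption made explicit in this sketch (the wrapper name is ours, not part of the patch):

#include "nnacl_c/fp16_grad/dropout_grad.h"

// dx = dy * mask * scale: dropped positions (mask == 0) get zero gradient,
// kept ones are rescaled. Assumes ratio < 1.
void DropoutBackwardExample(const float16_t *dy, const float16_t *mask, float16_t *dx, int length, float ratio) {
  float16_t scale = (float16_t)(1.0f / (1.0f - ratio));
  DropoutFp16Grad(dy, mask, dx, length, scale);
}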
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..629f199612255aaa3b4caaba042ab678855b4047
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP16_GRAD_DROPOUT_GRAD_H_
+#define NNACL_FP16_GRAD_DROPOUT_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void DropoutFp16Grad(const float16_t *yt_ptr, const float16_t *mask, float16_t *output_ptr, int length,
+                     float16_t scale);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_DROPOUT_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c
new file mode 100644
index 0000000000000000000000000000000000000000..408ef280bb20cbe9939ba6ca5e28ecaad9cb19b7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c
@@ -0,0 +1,385 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16_grad/gemm_fp16.h"
+#include <string.h>
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/fp16/matmul_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+
+#ifdef ENABLE_ARM64
+static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
+  size_t stride = col * 2;
+  asm volatile(
+    "mov x10, %[src_c]\n"
+    "mov x11, %[dst_c]\n"
+
+    "ld1 {v0.8h}, [x10], %[stride]\n"
+    "ld1 {v1.8h}, [x10], %[stride]\n"
+    "ld1 {v2.8h}, [x10], %[stride]\n"
+    "ld1 {v3.8h}, [x10], %[stride]\n"
+    "ld1 {v4.8h}, [x10], %[stride]\n"
+    "ld1 {v5.8h}, [x10], %[stride]\n"
+    "ld1 {v6.8h}, [x10], %[stride]\n"
+    "ld1 {v7.8h}, [x10], %[stride]\n"
+
+    "zip1 v16.8h, v0.8h, v1.8h\n"
+    "zip1 v17.8h, v2.8h, v3.8h\n"
+    "zip1 v18.8h, v4.8h, v5.8h\n"
+    "zip1 v19.8h, v6.8h, v7.8h\n"
+
+    "ld1 {v8.8h}, [x10], %[stride]\n"
+    "ld1 {v9.8h}, [x10], %[stride]\n"
+    "ld1 {v10.8h}, [x10], %[stride]\n"
+    "ld1 {v11.8h}, [x10], %[stride]\n"
+    "ld1 {v12.8h}, [x10], %[stride]\n"
+    "ld1 {v13.8h}, [x10], %[stride]\n"
+    "ld1 {v14.8h}, [x10], %[stride]\n"
+    "ld1 {v15.8h}, [x10], %[stride]\n"
+
+    "trn1 v20.4s, v16.4s, v17.4s\n"
+    "trn2 v21.4s, v16.4s, v17.4s\n"
+    "trn1 v22.4s, v18.4s, v19.4s\n"
+    "trn2 v23.4s, v18.4s, v19.4s\n"
+
+    "trn1 v24.2d, v20.2d, v22.2d\n"
+    "trn2 v25.2d, v20.2d, v22.2d\n"
+    "trn1 v26.2d, v21.2d, v23.2d\n"
+    "trn2 v27.2d, v21.2d, v23.2d\n"
+
+    "zip1 v16.8h, v8.8h, v9.8h\n"
+    "zip1 v17.8h, v10.8h, v11.8h\n"
+    "zip1 v18.8h, v12.8h, v13.8h\n"
+    "zip1 v19.8h, v14.8h, v15.8h\n"
+
+    "trn1 v20.4s, v16.4s, v17.4s\n"
+    "trn2 v21.4s, v16.4s, v17.4s\n"
+    "trn1 v22.4s, v18.4s, v19.4s\n"
+    "trn2 v23.4s, v18.4s, v19.4s\n"
+
+    "trn1 v28.2d, v20.2d, v22.2d\n"
+    "trn2 v29.2d, v20.2d, v22.2d\n"
+    "trn1 v30.2d, v21.2d, v23.2d\n"
+    "trn2 v31.2d, v21.2d, v23.2d\n"
+
+    "st1 {v24.8h}, [x11], #16\n"
+    "st1 {v28.8h}, [x11], #16\n"
+    "st1 {v26.8h}, [x11], #16\n"
+    "st1 {v30.8h}, [x11], #16\n"
+    "st1 {v25.8h}, [x11], #16\n"
+    "st1 {v29.8h}, [x11], #16\n"
+    "st1 {v27.8h}, [x11], #16\n"
+    "st1 {v31.8h}, [x11], #16\n"
+
+    "zip2 v16.8h, v0.8h, v1.8h\n"
+    "zip2 v17.8h, v2.8h, v3.8h\n"
+    "zip2 v18.8h, v4.8h, v5.8h\n"
+    "zip2 v19.8h, v6.8h, v7.8h\n"
+
+    "trn1 v20.4s, v16.4s, v17.4s\n"
+    "trn2
v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [dst_c] "r"(dst_ptr), [src_c] "r"(src_ptr), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +void AddMatrixFp16(const float16_t *restrict v1, float16_t *restrict v2, float16_t beta, int row, int col, int stride) { + const float16_t *src_ptr = v1; + float16_t *dst_ptr = v2; +#ifdef ENABLE_NEON + float16x8_t beta_0 = vdupq_n_f16(beta); +#endif + for (int r = 0; r < row; r++) { + int c = 0; +#ifdef ENABLE_NEON + for (; c <= (col - C8NUM); c += C8NUM) { + float16x8_t dst_0 = vld1q_f16(dst_ptr + c); + float16x8_t src_0 = vld1q_f16(src_ptr + c); + float16x8_t sum_0 = vfmaq_f16(dst_0, beta_0, src_0); + vst1q_f16(dst_ptr + c, sum_0); + } +#endif + for (; c < col; c++) { + dst_ptr[c] += beta * src_ptr[c]; + } + src_ptr += stride; + dst_ptr += stride; + } +} + +int MatSizeFp16(int row, int col, int round) { + int res = UP_ROUND(row, round) * col; + return res; +} + +int MatSizeTotalFp16(int row, int col, int deep, int stride) { +#ifdef ENABLE_ARM64 + const int num = C16NUM; +#else + const int num = C12NUM; +#endif + int res = MatSizeFp16(row, deep, num) + MatSizeFp16(col, deep, C8NUM); + if (stride > 0) res += row * stride; + return res; +} + +#ifdef ENABLE_ARM64 +static void RowMajor2Col16MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + size_t row_up_16 = UP_ROUND(row, C16NUM); + size_t row16 = row / C16NUM * C16NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src; + float16_t *dst_r = dst; + size_t ri = 0; + // find 16 block unit + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_ARM64 + Row2Col16Block16(src_c, dst_c, stride); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * stride + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * stride]; + } + } + src_r += C16NUM * stride; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += stride; + dst_r += 1; + } + for (; ri < row_up_16; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + return; +} +#endif + +void 
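+/* The packer below lays out a row-major [row x col] block (leading dimension stride) as 16-wide column panels: element (r, c) lands at dst[(c / 16) * 16 * row + r * 16 + (c % 16)], which appears to be the layout the ARM64 16-row fp16 matmul path reads sequentially. */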
RowMajor2Row16MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * stride + c]; + } + } +} + +void RowMajor2Col12MajorStrideFp16(const float16_t *src, float16_t *dst, size_t row, size_t col, int stride) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src; + float16_t *dst_r = dst; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, stride * sizeof(float16_t), 24); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * stride + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * stride]; + } + } + src_r += C12NUM * stride; + dst_r += C12NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += stride; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Row12MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div12 = c / C12NUM; + int c_mod12 = c % C12NUM; + dst[c_div12 * C12NUM * row + r * C12NUM + c_mod12] = src[r * stride + c]; + } + } +} + +static void RowMajor2Col8MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div8 = r / C8NUM; + int r_mod8 = r % C8NUM; + dst[r_div8 * C8NUM * col + c * C8NUM + r_mod8] = src[r * stride + c]; + } + } +} + +static void RowMajor2Row8MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + const float16_t *src_ptr = src + r * stride; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst[cd8 * C8NUM * row + r * C8NUM + cm8] = src_ptr[c]; + } + for (; c < UP_ROUND(col, C8NUM); c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst[cd8 * C8NUM * row + r * C8NUM + cm8] = 0; + } + } + return; +} + +static void RowMajor2ColXMajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorStrideFp16(src, dst, row, col, stride); +#else + RowMajor2Col12MajorStrideFp16(src, dst, row, col, stride); +#endif +} + +static void RowMajor2RowXMajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { +#ifdef ENABLE_ARM64 + RowMajor2Row16MajorStrideFp16(src, dst, row, col, stride); +#else + RowMajor2Row12MajorStrideFp16(src, dst, row, col, stride); +#endif +} + +void GemmMatmulFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, float16_t *workspace) { + GemmCbFp16 gcb; + gcb.atype = ActType_No; + gcb.ca = 0; + gcb.cb = 0; + gcb.bias = NULL; + 
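+  /* Plain GEMM is expressed as the callback variant with no bias, no activation and unpacked operands (gcb.ca/gcb.cb flag pre-packed inputs); both matrices are then repacked into workspace, which is presumably sized via MatSizeTotalFp16 -- that helper also reserves the extra M x ldc output buffer used when beta != 0. */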
GemmMatmulPlusFp16(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb); +} + +void GemmMatmulPlusFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, + float16_t *workspace, GemmCbFp16 *gcb) { +#ifdef ENABLE_ARM64 + const int num = C16NUM; +#else + const int num = C12NUM; +#endif + float16_t *output = mat_c; + float16_t *fworkspace = workspace; + int incremental = (beta < 0.f) || (beta > 0.f); + float16_t *mat_a_input = (float16_t *)mat_a; + float16_t *mat_b_input = (float16_t *)mat_b; + + if (!gcb->ca) { + mat_a_input = fworkspace; + fworkspace += MatSizeFp16(M, K, num); + if (ta) { + RowMajor2RowXMajorStrideFp16(mat_a, mat_a_input, K, M, lda); + } else { + RowMajor2ColXMajorStrideFp16(mat_a, mat_a_input, M, K, lda); + } + } + if (!gcb->cb) { + mat_b_input = fworkspace; + fworkspace += MatSizeFp16(N, K, C8NUM); + if (tb) { + RowMajor2Col8MajorStrideFp16(mat_b, mat_b_input, N, K, ldb); + } else { + RowMajor2Row8MajorStrideFp16(mat_b, mat_b_input, K, N, ldb); + } + } + if (incremental) output = fworkspace; + MatMulFp16(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); + if (incremental) AddMatrixFp16(output, mat_c, beta, M, N, ldc); + gcb->mat_a = mat_a_input; + gcb->mat_b = mat_b_input; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..3f964c27c513700ef2aa3e0eecdd38d13afa2c41 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_GRAD_GEMM_FP16_H_ +#define NNACL_FP16_GRAD_GEMM_FP16_H_ + +#include <stdlib.h> +#include "nnacl_c/op_base.h" +#ifdef __cplusplus +extern "C" { +#endif +typedef struct { + int ca; + int cb; + ActType atype; + float16_t *bias; + float16_t *mat_a; + float16_t *mat_b; +} GemmCbFp16; + +void GemmMatmulFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, float16_t *workspace); +void GemmMatmulPlusFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, + float16_t *workspace, GemmCbFp16 *gcb); +int MatSizeFp16(int row, int col, int round); +int MatSizeTotalFp16(int row, int col, int deep, int stride); +void AddMatrixFp16(const float16_t *v1, float16_t *v2, float16_t beta, int row, int col, int stride); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_GEMM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..c7a819c45f1cc817b01c02f38815595e61b6396e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp16_grad/layernorm_grad.h" +#include +#include + +void LayerNormFp16Grad(const float16_t *x, const float16_t *dy, const float16_t *var, const float16_t *mean, + const float16_t *gamma, int param_num, int param_size, int block_num, int block_size, + float16_t *dx, float16_t *dg, float16_t *db) { + // var is actually 1/sqrf(var)-> var^0.5 + NNACL_CHECK_ZERO_RETURN(block_size); + const float16_t *var_sqrt_rev = var; + for (size_t i = 0; i < param_num; ++i) { + float dgamma = 0.0f; + float dbeta = 0.0f; + for (size_t j = i; j < param_size * param_num; j += param_num) { + int norm_shift = (int)(j / block_size); + dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]); + dbeta += dy[j]; + } + dg[i] = (float16_t)dgamma; + db[i] = (float16_t)dbeta; + } + for (size_t i = 0; i < block_num; ++i) { + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float16_t dxm = x[j] - mean[norm_shift]; + float16_t dyg = dy[j] * gamma[param_shift]; + sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift]; + sum2 += dyg; + sum3 += -2.0f * dxm; + } + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float16_t var_sqrt = var_sqrt_rev[norm_shift]; + float dx1 = dy[j] * gamma[param_shift] * var_sqrt; + float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); + float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); + dx[j] = (float16_t)(dx1 + dx2 + dx3); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..1eca819df5987c5c0a1b127ff0107ea51a859748 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ +#define NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void LayerNormFp16Grad(const float16_t *x, const float16_t *dy, const float16_t *var, const float16_t *mean, + const float16_t *gamma, int param_num, int param_size, int block_num, int block_size, + float16_t *dx, float16_t *dg, float16_t *db); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c new file mode 100644 index 0000000000000000000000000000000000000000..a2c5b47be8620f9956c6e120605a219227ab3a28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c @@ -0,0 +1,201 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <string.h> +#include "nnacl_c/fp16_grad/pack_fp16_ext.h" + +void RollingIm2ColPackDwUnitFp16(const float16_t *in_data, const ConvParameter *conv_param, float16_t *data_col_orig, + int real_cal_num, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_; + const int stride = kernel_h * kernel_w; + + int kernel_row, kernel_col; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + float16_t *data_col = data_col_orig + i * channels * stride; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * channels; + for (int c = 0; c < channels; c++) { + data_col[c * stride] = in_data[offset + c]; + } + data_col++; + } else { + for (int c = 0; c < channels; c++) { + data_col[c * stride] = 0; + } + data_col++; + } + } + } + } +} + +void RollingIm2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = 
conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + *packed_input = input_data[offset]; + packed_input++; + } else { + *packed_input = 0; + packed_input++; + } + } + } + } + } else { + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(packed_input, input_data + offset, sizeof(float16_t) * channels); + packed_input += channels; + } else { + memset(packed_input, 0, sizeof(float16_t) * channels); + packed_input += channels; + } + } + } + } + } +} + +void RollingCol2ImPackUnitFp16(const float16_t *data_col, float16_t *data_im, const ConvParameter *conv_param, + int real_cal_num, int block_index) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int r = 0; r < real_cal_num; r++) { + int output_col = (block_index + r) % output_w; + int output_row = (block_index + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && 
((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float16_t *data_im_ptr = data_im + offset; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < real_cal_num; r++) { + int output_col = (block_index + r) % output_w; + int output_row = (block_index + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float16_t *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h new file mode 100644 index 0000000000000000000000000000000000000000..0d6d6841757390c2d3f10b01c1a1c3f0254a0fa4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_PACK_FP16_EXT_H_ +#define NNACL_FP16_GRAD_PACK_FP16_EXT_H_ + +#include +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void RollingIm2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); +void RollingIm2ColPackDwUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); +void RollingCol2ImPackUnitFp16(const float16_t *data_col, float16_t *data_im, const ConvParameter *conv_param, + int real_cal_num, int block_index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_PACK_FP16_EXT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..4ab7f96b9beef5b689421539c9779693f67271ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c @@ -0,0 +1,192 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <stddef.h> +#include <float.h> +#include <math.h> +#include "nnacl_c/fp16_grad/pooling_grad.h" +#include "nnacl_c/op_base.h" + +void AvgPoolingFp16Grad(const float16_t *input_ptr, float16_t *output_ptr, int count, PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + const float16_t kk = 1.0f / (float16_t)(win_h * win_w); +#ifdef ENABLE_NEON + const float16x4_t factor = vdup_n_f16(kk); +#endif + for (int ib = 0; ib < count; ib++) { + float16_t *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float16_t *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; + // iterate over yt + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_NEON + float16x4_t in = vld1_f16(inPtr + idx); + float16x4_t delta = vmul_f16(in, factor); +#else + float16_t delta[C4NUM] = {inPtr[idx], inPtr[idx + C1NUM], inPtr[idx + C2NUM], inPtr[idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) delta[i] *= kk; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; +#ifdef ENABLE_NEON + float16_t *out_vec = out + (xw + in_w * xh) * channel + ic; + float16x4_t outr = vld1_f16(out + (xw + in_w * xh) * channel + ic); + float16x4_t outs = vadd_f16(outr, delta); + vst1_f16(out_vec, outs); +#else + for (int i = 0; i < C4NUM; i++) { + out[(xw + in_w * xh) * channel + ic + i] += ((float16_t *)&delta)[i]; + } +#endif + } + } + } + for (; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; + float16_t delta = inPtr[idx] * kk; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + out[(xw + in_w * xh) * channel + ic] += delta; + } + } + } + } + } + } +} + +#ifdef ENABLE_NEON +static int32x4_t MaxIndex(float16x4_t in, float16x4_t *max, uint32x4_t index, uint32x4_t prev_index) { + uint16x4_t res = vcgt_f16(in, *max); + int16x4_t tmp = vreinterpret_s16_u16(res); + uint32x4_t res_tmp = vreinterpretq_u32_s32(vmovl_s16(tmp)); + int32x4_t m_index = vbslq_s32(res_tmp, (int32x4_t)index, (int32x4_t)prev_index); + *max = vbsl_f16(res, in, *max); + return m_index; +} +#endif + +void MaxPoolingFp16Grad(const float16_t *input_ptr, const float16_t *dy_ptr, float16_t 
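+/* MaxPoolingFp16Grad recomputes the forward argmax of every pooling window from input_ptr (four channels per step under NEON, scalar otherwise) and scatters the matching dy element onto the winning input position in the output gradient. */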
*output_ptr, int output_batch, + PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + for (int ib = 0; ib < output_batch; ib++) { + float16_t *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float16_t *inPtr = &input_ptr[(ib * in_h * in_w * channel)]; + const float16_t *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)]; + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < (channel & ~3); ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_NEON + uint32x4_t max_idx = vdupq_n_u32(0); + float16x4_t max_val = vdup_n_f16(-FLT16_MAX); + float16x4_t delta = vld1_f16(dyPtr + idx); +#else + float16_t delta[C4NUM] = {dyPtr[idx], dyPtr[idx + C1NUM], dyPtr[idx + C2NUM], dyPtr[idx + C3NUM]}; + float16_t max_val[C4NUM] = {-FLT16_MAX, -FLT16_MAX, -FLT16_MAX, -FLT16_MAX}; + unsigned int max_idx[C4NUM] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_NEON + uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + float16x4_t in = vld1_f16(inPtr + val_idx); + max_idx = (uint32x4_t)MaxIndex(in, &max_val, index, max_idx); +#else + float16_t val[C4NUM] = {inPtr[val_idx], inPtr[val_idx + C1NUM], inPtr[val_idx + C2NUM], + inPtr[val_idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < C4NUM; i++) { + out[((int *)&max_idx)[i]] += ((float16_t *)&delta)[i]; + } + } + for (; ic < channel; ic++) { + float16_t max_val = -FLT16_MAX; + int max_idx = 0; + int idx = (yw + yh * output_w) * channel + ic; + float16_t delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float16_t val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; + } + } + } + out[max_idx] += delta; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..5c2bdee201e2f1ee6774eb1a1817be1136ab8955 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_POOLING_GRAD_H_ +#define NNACL_FP16_GRAD_POOLING_GRAD_H_ + +#include "nnacl_c/fp16/pooling_fp16.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +void AvgPoolingFp16Grad(const float16_t *input_ptr, float16_t *output_ptr, int count, PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args); +void MaxPoolingFp16Grad(const float16_t *input_ptr, const float16_t *dy_ptr, float16_t *output_ptr, int output_batch, + PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_POOLING_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..6806293a241392adb0f38aaf158acdddd9acbe21 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c @@ -0,0 +1,146 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16_grad/resize_grad.h" +#include <math.h> +#include "nnacl_c/infer/common_infer.h" + +int ResizeNearestNeighborFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param) { + bool align_corners = param->align_corners_; + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t in_y = i / param->in_width_; + size_t in_x = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + size_t out_y = MSMIN( + (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), + param->out_height_ - 1); + size_t out_x = MSMIN( + (align_corners) ? 
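+            /* with align_corners the scaled source coordinate is rounded to the nearest output pixel, otherwise it is floored -- mirroring the coordinate mapping of the forward nearest-neighbour resize before the clamp to the output bounds. */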
(size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), + param->out_width_ - 1); + size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; + size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; + out_addr[out_offset] += in_addr[in_offset]; + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + for (int32_t b = 0; b < batch_size; ++b) { + for (int32_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < param->in_height_; ++h) { + size_t out_y = + MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_), + param->out_height_ - 1); + for (size_t w = 0; w < param->in_width_; ++w) { + size_t out_x = + MSMIN((align_corners) ? (size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), + param->out_width_ - 1); + out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} + +int ResizeBiLinearFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + float16_t in_y = (float16_t)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float16_t y_lerp = in_y - floorf(in_y); + const float16_t inverse_y_lerp = 1.0 - y_lerp; + + float16_t in_x = (float16_t)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float16_t x_lerp = in_x - floorf(in_x); + const float16_t inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float16_t)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float16_t)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float16_t)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float16_t)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + size_t in_height = param->in_height_; + size_t in_width = param->in_width_; + size_t out_height = param->out_height_; + size_t out_width = param->out_width_; + + for (size_t b = 0; b < batch_size; ++b) { + for (size_t c = 
0; c < channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float16_t in_y = (float16_t)(h)*param->height_scale_; + const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); + const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); + const float16_t y_lerp = in_y - floorf(in_y); + const float16_t inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float16_t in_x = (float16_t)(w)*param->width_scale_; + const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); + const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); + const float16_t x_lerp = in_x - floorf(in_x); + const float16_t inverse_x_lerp = 1.0 - x_lerp; + out_addr[top_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float16_t)(inverse_y_lerp * inverse_x_lerp); + out_addr[top_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float16_t)(inverse_y_lerp * x_lerp); + out_addr[bottom_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float16_t)(y_lerp * inverse_x_lerp); + out_addr[bottom_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float16_t)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..e25fd0b4406124e37dc84249367bbd6f5afdef6c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_GRAD_RESIZE_GRAD_H_ +#define NNACL_FP16_GRAD_RESIZE_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ResizeFp16GradParameter { + OpParameter op_parameter_; + bool align_corners_; + int method; + size_t in_height_; + size_t in_width_; + size_t out_height_; + size_t out_width_; + float16_t height_scale_; + float16_t width_scale_; +} ResizeFp16GradParameter; + +int ResizeNearestNeighborFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param); +int ResizeBiLinearFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..775824681b6043c8219396bd59b5c7b670ba51e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c @@ -0,0 +1,67 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16_grad/strided_slice_grad.h" +#include "nnacl_c/errorcode.h" + +static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) { + size_t res = 1; + for (size_t j = 0; j < size; j++) { + res *= shape[(i + 1) + j]; + } + NNACL_CHECK_ZERO_RETURN_ERR(res); + NNACL_CHECK_ZERO_RETURN_ERR(shape[i]); + return (pos / res % shape[i]); +} + +int DoStridedSliceFp16Grad(const float16_t *inputs, float16_t *output, const int *dx_shape, + StridedSliceParameter *param) { + if (inputs == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_7D) { + return NNACL_PARAM_INVALID; + } + + size_t size = 1; + int *s = param->strides_; + int *b = param->begins_; + for (int i = 0; i < DIMENSION_7D; i++) { + size *= param->in_shape_[i]; + } + + for (size_t pos = 0; pos < size; pos++) { + size_t i = CalcIndex(param->in_shape_, C6NUM, C0NUM, pos); + size_t j = CalcIndex(param->in_shape_, C5NUM, C1NUM, pos); + size_t k = CalcIndex(param->in_shape_, C4NUM, C2NUM, pos); + size_t l = CalcIndex(param->in_shape_, C3NUM, C3NUM, pos); + size_t m = CalcIndex(param->in_shape_, C2NUM, C4NUM, pos); + size_t n = CalcIndex(param->in_shape_, C1NUM, C5NUM, pos); + size_t o = CalcIndex(param->in_shape_, C0NUM, C6NUM, pos); + + size_t input_idx = + (i * s[C0NUM] + b[C0NUM]) * dx_shape[C1NUM] * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * + dx_shape[C5NUM] * dx_shape[C6NUM] + + (j * s[C1NUM] + b[C1NUM]) * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * + dx_shape[C6NUM] + + (k * s[C2NUM] + b[C2NUM]) * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] + + (l * s[C3NUM] + b[C3NUM]) * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] + + (m * s[C4NUM] + b[C4NUM]) * dx_shape[C5NUM] * dx_shape[C6NUM] + (n * s[C5NUM] + b[C5NUM]) * dx_shape[C6NUM] + + (o * s[C6NUM] + b[C6NUM]); + output[input_idx] = inputs[pos]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..6a79e8e9277885b1e0de816be1e166a74db8fbd0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ +#define NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoStridedSliceFp16Grad(const float16_t *inputs, float16_t *output, const int *dx_shape, + StridedSliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c new file mode 100644 index 0000000000000000000000000000000000000000..8f794d42e3600871c9dd8b623f2608f35616badd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16_grad/unsorted_segment_sum.h" +#include "nnacl_c/errorcode.h" + +int UnsortedSegmentSumFp16(const float16_t *input, int unit_num, int input_dim1, const int *indices, float16_t *output, + int output_dim0, int output_dim1) { + NNACL_CHECK_ZERO_RETURN_ERR(input_dim1); + for (int i = 0; i < unit_num; ++i) { + int j = i / input_dim1; + int k = i % input_dim1; + + int index = indices[j]; + if (index < 0 || index >= output_dim0) { + continue; + } + int output_index = index * output_dim1 + k; + output[output_index] += input[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h new file mode 100644 index 0000000000000000000000000000000000000000..85a54ab2f6be99157b361ea7a02ab551a920cc51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ +#define NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsortedSegmentSumFp16(const float16_t *input, int unit_num, int input_dim1, const int *indices, float16_t *output, + int output_dim0, int output_dim1); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f284b75aeaf925835c2f5fa8e8c0f5cd519c23bc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c @@ -0,0 +1,292 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <math.h> +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/activation_fp32_simd.h" + +int Fp32Relu(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Relu, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +int Int32Relu(const int32_t *src, int length, int32_t *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Int32Relu, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +int Fp32Relu6(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Relu6, i, src, length, dst); + + for (; i < length; ++i) { + if (src[i] < 0) { + dst[i] = 0; + } else { + dst[i] = src[i] > 6.0f ? 6.0f : src[i]; // relu 6.0 + } + } + return NNACL_OK; +} + +int Fp32Clip(const float *src, int length, float *dst, float min, float max) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Clip, i, src, length, dst, min, max); + + for (; i < length; ++i) { + if (src[i] < min) { + dst[i] = min; + } else { + dst[i] = src[i] > max ? max : src[i]; + } + } + return NNACL_OK; +} + +int Int32Clip(const int32_t *src, int length, int32_t *dst, int min, int max) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Int32Clip, i, src, length, dst, min, max); + + for (; i < length; ++i) { + if (src[i] < min) { + dst[i] = min; + } else { + dst[i] = src[i] > max ? max : src[i]; + } + } + return NNACL_OK; +} + +int LRelu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(LRelu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? 
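+    /* scalar tail: SIMD_RUN_NO_SCALAR advances i through the vectorized span in BLOCK_NUM steps (see the generated kernels in activation_fp32_simd.h.in below), so this loop only handles the remaining few elements. */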
src[i] : (src[i] * alpha); + } + return NNACL_OK; +} + +int Sigmoid(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Sigmoid, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(-src[i], dst + i); + dst[i] = 1.0f / (1.0f + dst[i]); + } + return NNACL_OK; +} + +float TanhOpt(float src) { + if (src > 5.0) { // src > 5.0, tanh(src) = 1.0f + return 1.0f; + } else if (src < -5.0) { // src < -5.0, tanh(src) = -1.0f + return -1.0f; + } else { + float square = src * src; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + return a / b; + } +} + +int Tanh(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Tanh, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = TanhOpt(src[i]); + } + return NNACL_OK; +} + +int Swish(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Swish, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(-src[i], dst + i); + dst[i] = src[i] / (1.0f + dst[i]); + } + return NNACL_OK; +} + +int HSwish(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(HSwish, i, src, length, dst); + + for (; i < length; ++i) { + float in = src[i]; + float relu6 = MSMIN(MSMAX(in + C3NUM, 0), C6NUM); + dst[i] = in * relu6 / C6NUM; + } + return NNACL_OK; +} + +int HSigmoid(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(HSigmoid, i, src, length, dst); + + for (; i < length; ++i) { + float relu6 = MSMIN(MSMAX(src[i] + C3NUM, 0), C6NUM); + dst[i] = relu6 / C6NUM; + } + return NNACL_OK; +} + +int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + if (min_val == FLT_MIN) { + SIMD_RUN_NO_SCALAR(HardTanhNoLimitMin, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] > max_val ? max_val : src[i]; + } + } else if (max_val == FLT_MAX) { + SIMD_RUN_NO_SCALAR(HardTanhNoLimitMax, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : src[i]; + } + } else { + SIMD_RUN_NO_SCALAR(HardTanhLimitMinMax, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? 
max_val : src[i]); + } + } + return NNACL_OK; +} + +int Gelu(const float *src, int length, float *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + SIMD_RUN_NO_SCALAR(GeluTanhApproximate, i, src, length, dst); + + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) + for (; i < length; i++) { + dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { + SIMD_RUN_NO_SCALAR(GeluErfAPPROXIMATE, i, src, length, dst); + + for (; i < length; i++) { + dst[i] = + 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); // dst = 0.5 * x * (1.0 + erf(x / 1.4142135623730951f)) + } + } + return NNACL_OK; +} + +int Softplus(const float *src, int length, float *dst) { + float log_max = 88.0; + int i = 0; + + SIMD_RUN_NO_SCALAR(Softplus, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(src[i], dst + i); + if (src[i] > log_max) { + dst[i] = src[i]; + } else { + dst[i] = log1p(dst[i]); + } + } + return NNACL_OK; +} + +int Elu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Elu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha); + } + return NNACL_OK; +} + +void Celu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Celu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i] / alpha) * alpha); + } + return; +} + +int HardShrink(const float *src, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(HardShrink, i, src, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = src[i] >= neg_lambd && src[i] <= lambd ? 0 : src[i]; + } + return NNACL_OK; +} + +int SoftShrink(const float *src, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(SoftShrink, i, src, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src[i] > lambd) ? (src[i] - lambd) : ((src[i] < neg_lambd) ? (src[i] + lambd) : (0)); + } + return NNACL_OK; +} + +int SoftsignFp32Opt(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(SoftsignFp32Opt, i, src, length, dst); + for (; i < length; ++i) { + dst[i] = src[i] / (1.0 + fabsf(src[i])); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..28a3da19e31b20f45f243557c9d97486d7dd40c1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_H_
+#define MINDSPORE_NNACL_FP32_ACTIVATION_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/activation_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int Fp32Relu(const float *src, int length, float *dst);
+int Int32Relu(const int32_t *src, int length, int32_t *dst);
+int Fp32Relu6(const float *src, int length, float *dst);
+int Fp32Clip(const float *src, int length, float *dst, float min, float max);
+int Int32Clip(const int32_t *src, int length, int32_t *dst, int min, int max);
+int LRelu(const float *src, int length, float *dst, float alpha);
+int Sigmoid(const float *src, int length, float *dst);
+int Tanh(const float *src, int length, float *dst);
+int HSigmoid(const float *src, int length, float *dst);
+int Swish(const float *src, int length, float *dst);
+int HSwish(const float *src, int length, float *dst);
+int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
+int Gelu(const float *src, int length, float *dst, bool approximate);
+int Softplus(const float *src, int length, float *dst);
+int Elu(const float *src, int length, float *dst, float alpha);
+void Celu(const float *src, int length, float *dst, float alpha);
+float TanhOpt(float src);
+int HardShrink(const float *src, int length, float *dst, float lambd);
+int SoftShrink(const float *src, int length, float *dst, float lambd);
+int SoftsignFp32Opt(const float *src, int length, float *dst);
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_FP32_ACTIVATION_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..8e9c48bc18a9e28610bc81d1c746b1abe93fc400
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in
@@ -0,0 +1,289 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int Fp32Relu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
+  SIMD_F32 zero = SIMD_SET0_F32;
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(dst + index, SIMD_MAX_F32(SIMD_LD_F32(src + index), zero));
+  }
+  return index;
+}
+
+static inline int Int32Relu@SIMD_INSTRUCTION@(int index, const int32_t *src, int length, int32_t *dst) {
+  SIMD_EPI32 zero = SIMD_MOV_EPI32(0);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(dst + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(src + index), zero));
+  }
+  return index;
+}
+
+static inline int Fp32Relu6@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
+  SIMD_F32 zero = SIMD_SET0_F32;
+  SIMD_F32 six = SIMD_MOV_F32(6.0f);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), zero, six));
+  }
+  return index;
+}
+
+static inline int Fp32Clip@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min, float max) {
+  SIMD_F32 min_val = SIMD_MOV_F32(min);
+  SIMD_F32 max_val = SIMD_MOV_F32(max);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), min_val, max_val));
+  }
+  return index;
+}
+
+static inline int Int32Clip@SIMD_INSTRUCTION@(int index, const int32_t *src, int length, int32_t *dst, int min, int max) {
+  SIMD_EPI32 min_val = SIMD_MOV_EPI32(min);
+  SIMD_EPI32 max_val = SIMD_MOV_EPI32(max);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(dst + index, SIMD_CLAMP_EPI32(SIMD_LD_EPI32(src + index), min_val, max_val));
+  }
+  return index;
+}
+
+static inline int LRelu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) {
+  SIMD_F32 alpha_data = SIMD_MOV_F32(alpha);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 src_tmp = SIMD_LD_F32(src + index);
+    SIMD_MASK mask = SIMD_CMPGT_F32(SIMD_SET0_F32, src_tmp);
+    SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_F32(src_tmp, alpha_data), mask));
+  }
+  return index;
+}
+
+static inline int Sigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, (SIMD_LD_F32(src + index))), dst + index);
+    SIMD_ST_F32(dst + index,
+                SIMD_DIV_F32(SIMD_MOV_F32(1.0f), SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index))));
+  }
+  return index;
+}
+
+static inline int Softplus@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
+  SIMD_F32 log_max = SIMD_MOV_F32(88.0f);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 src_tmp = SIMD_LD_F32(src + index);
+    SIMD_F32 dst_tmp = SIMD_EXP_F32(src_tmp);
+    dst_tmp =
SIMD_LOG_F32(SIMD_ADD_F32(SIMD_MOV_F32(1.0f), dst_tmp)); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(dst_tmp, src_tmp, SIMD_CMPGT_F32(src_tmp, log_max))); + } + return index; +} + +static inline int Tanh@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + index); + SIMD_ST_F32(dst + index, SIMD_TANH_F32(input)); + } + return index; +} + +static inline int Swish@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, src_value), dst + index); + SIMD_ST_F32(dst + index, + SIMD_DIV_F32(src_value, SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index)))); + } + return index; +} + +static inline int HSwish@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_F32 relu6 = SIMD_CLAMP_N_F32(SIMD_ADD_N_F32(src_value, 3), 0, 6); + SIMD_ST_F32(dst + index, SIMD_DIV_N_F32(SIMD_MUL_F32(src_value, relu6), 6)); + } + return index; +} + +static inline int HSigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_F32 relu6 = SIMD_CLAMP_N_F32(SIMD_ADD_N_F32(src_value, 3), 0, 6); + SIMD_ST_F32(dst + index, SIMD_DIV_N_F32(relu6, 6)); + } + return index; +} + +static inline int HardTanhNoLimitMin@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MIN_N_F32(SIMD_LD_F32(src + index), max_val)); + } + return index; +} + +static inline int HardTanhNoLimitMax@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MAX_N_F32(SIMD_LD_F32(src + index), min_val)); + } + return index; +} + +static inline int HardTanhLimitMinMax@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_N_F32(SIMD_LD_F32(src + index), min_val, max_val)); + } + return index; +} + +static inline int GeluTanhApproximate@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 tmp1 = SIMD_FMADD_F32(SIMD_MUL_N_F32(in, 0.035677408136f), in, SIMD_MOV_F32(0.79788456080287f)); + SIMD_F32 tmp2 = SIMD_MUL_F32(tmp1, in); + SIMD_ST_F32(dst + index, SIMD_MUL_F32(SIMD_MUL_N_F32(in, 0.5f), SIMD_ADD_N_F32(SIMD_TANH_F32(tmp2), 1.0f))); + } + return index; +} + +static inline int Gelu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 para1 = SIMD_MOV_F32(1.4142135623730951f); + SIMD_F32 para2 = SIMD_MOV_F32(1.0f); + 
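  /* GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))); para1 = sqrt(2), para2 = 1, para3 = 0.5 */
+ 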
SIMD_F32 para3 = SIMD_MOV_F32(0.5f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 res = SIMD_MUL_F32(SIMD_MUL_F32(para3, in), SIMD_ADD_F32(para2, SIMD_ERF_F32(SIMD_DIV_F32(in, para1)))); + SIMD_ST_F32(dst + index, res); + } + return index; +} + +static inline SIMD_F32 SIMD_ERFCCHEB@SIMD_INSTRUCTION@(SIMD_F32 src) { + static const int ncof = 7; + const double cof[7] = {-1.3026537197817094, 6.4196979235649026e-1, 1.9476473204185836e-2, -9.561514786808631e-3, + -9.46595344482036e-4, 3.66839497852761e-4, 4.2523324806907e-5}; + SIMD_F32 dst; + SIMD_F32 d = SIMD_SET0_F32; + SIMD_F32 dd = SIMD_SET0_F32; + SIMD_F32 t = SIMD_DIV_F32(SIMD_MOV_F32(2.0f), SIMD_ADD_F32(src, SIMD_MOV_F32(2.0f))); + SIMD_F32 ty = SIMD_SUB_F32(SIMD_MUL_F32(SIMD_MOV_F32(4.0f), t), SIMD_MOV_F32(2.0f)); + + for (int j = ncof - 1; j > 0; j--) { + SIMD_F32 tmp = d; + d = SIMD_SUB_F32(SIMD_FMADD_F32(ty, d, SIMD_MOV_F32(cof[j])), dd); + dd = tmp; + } + + dst = + SIMD_FMADD_F32(src, src, MS_FSMUL_F32(dd, SIMD_FMADD_F32(ty, d, SIMD_MOV_F32(cof[0])), SIMD_MOV_F32(0.5f))); + dst = SIMD_MUL_F32(t, SIMD_EXP_F32(SIMD_MUL_F32(SIMD_MOV_F32(-1.0f), dst))); + return dst; +} + +static inline SIMD_F32 SIMD_ERF_APPROXIMATE@SIMD_INSTRUCTION@(SIMD_F32 src) { + SIMD_F32 abs_src = SIMD_ABS_F32(src); + SIMD_F32 sign = SIMD_GETSIGN_F32(src); + SIMD_F32 dst = SIMD_ERFCCHEB@SIMD_INSTRUCTION@(abs_src); + return SIMD_MUL_F32(sign, SIMD_SUB_F32(SIMD_MOV_F32(1.0f), dst)); +} + +static inline int GeluErfAPPROXIMATE@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 para1 = SIMD_MOV_F32(1.4142135623730951f); + SIMD_F32 para2 = SIMD_MOV_F32(1.0f); + SIMD_F32 para3 = SIMD_MOV_F32(0.5f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 res = SIMD_MUL_F32(SIMD_MUL_F32(para3, in), SIMD_ADD_F32(para2, SIMD_ERF_APPROXIMATE@SIMD_INSTRUCTION@(SIMD_DIV_F32(in, para1)))); + SIMD_ST_F32(dst + index, res); + } + return index; +} + +static inline int Elu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(src_tmp), 1.0f); + SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); + } + return index; +} + +static inline int Celu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(SIMD_DIV_N_F32(src_tmp, alpha)), 1.0f); + SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); + } + return index; +} + +static inline int HardShrink@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_t = SIMD_LD_F32(src + index); + /* v0 = (in > lamdb) & in */ + SIMD_F32 value0 = 
SIMD_AND_MASK_F32(SIMD_CMPGT_F32(src_t, pos_lamdb_v), src_t); + /* v1 = (in < -lamdb) & in */ + SIMD_F32 value1 = SIMD_AND_MASK_F32(SIMD_CMPLT_F32(src_t, neg_lamdb_v), src_t); + /* out = (v0 | v1) */ + SIMD_ST_F32(dst + index, SIMD_OR_F32(value0, value1)); + } + return index; +} + +static inline int SoftShrink@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_t = SIMD_LD_F32(src + index); + /* v0 = (in > lamdb) & (in - lamdb) */ + SIMD_F32 value0 = SIMD_AND_MASK_F32(SIMD_CMPGT_F32(src_t, pos_lamdb_v), SIMD_SUB_F32(src_t, pos_lamdb_v)); + /* v1 = (in < -lamdb) & (in + lamdb) */ + SIMD_F32 value1 = SIMD_AND_MASK_F32(SIMD_CMPLT_F32(src_t, neg_lamdb_v), SIMD_ADD_F32(src_t, pos_lamdb_v)); + /* out = (v0 | v1) */ + SIMD_ST_F32(dst + index, SIMD_OR_F32(value0, value1)); + } + return index; +} + +static inline int SoftsignFp32Opt@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 divisor_tmp = SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_ABS_F32(src_tmp)); + SIMD_ST_F32(dst + index, SIMD_DIV_F32(src_tmp, divisor_tmp)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..499957c98ce2febfafedf9a94ad2ec63cf43e558 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c @@ -0,0 +1,239 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <math.h>
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/fp32/adam_fp32.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#ifdef ENABLE_AVX512
+#include "nnacl_c/avx512/adam_fp32_avx512.h"
+#endif
+
+int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
+             size_t start, size_t end, bool use_nesterov) {
+  size_t c1 = start;
+#ifdef ENABLE_AVX
+  size_t c8 = ((end - start) / C8NUM) * C8NUM;
+  __m256 coeff1_r = _mm256_set1_ps(1 - beta1);
+  __m256 coeff2_r = _mm256_set1_ps(1 - beta2);
+  __m256 beta1_r = _mm256_set1_ps(beta1);
+  __m256 lr_r = _mm256_set1_ps(lr);
+  __m256 epsi_r = _mm256_set1_ps(epsilon);
+
+  float *var_ptr = var + start;
+  float *m_ptr = m + start;
+  float *v_ptr = v + start;
+  const float *grad_ptr = gradient + start;
+
+  __m256 avx_r0, avx_r1;
+  __m256 var_r, m_r, v_r, grad_r;
+
+  for (; c1 < start + c8; c1 += C8NUM) {
+    grad_r = _mm256_loadu_ps(grad_ptr);
+    m_r = _mm256_loadu_ps(m_ptr);
+    avx_r0 = _mm256_sub_ps(grad_r, m_r);
+    avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r);
+    m_r = _mm256_add_ps(m_r, avx_r1);
+    _mm256_storeu_ps(m_ptr, m_r);
+
+    v_r = _mm256_loadu_ps(v_ptr);
+    avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r);
+    v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r));
+    _mm256_storeu_ps(v_ptr, v_r);
+
+    if (use_nesterov) {
+      avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
+      avx_r1 = _mm256_mul_ps(lr_r, avx_r0);
+      avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
+      __m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0);
+
+      var_r = _mm256_loadu_ps(var_ptr);
+      var_r = _mm256_sub_ps(var_r, avx_r2);
+      _mm256_storeu_ps(var_ptr, var_r);
+    } else {
+      avx_r0 = _mm256_mul_ps(lr_r, m_r);
+      avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
+      __m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1);
+      var_r = _mm256_loadu_ps(var_ptr);
+      var_r = _mm256_sub_ps(var_r, avx_r2);
+      _mm256_storeu_ps(var_ptr, var_r);
+    }
+    m_ptr += C8NUM;
+    v_ptr += C8NUM;
+    var_ptr += C8NUM;
+    grad_ptr += C8NUM;
+  }
+#endif
+
+  // remaining
+  for (; c1 < end; c1++) {
+    m[c1] += (gradient[c1] - m[c1]) * (1 - beta1);
+    v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2);
+    if (use_nesterov) {
+      var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
+    } else {
+      var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon);
+    }
+  }
+  return NNACL_OK;
+}
+
+int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
+                  const float *gradient, size_t start, size_t end, bool use_nesterov) {
+  size_t c1 = start;
+#ifdef ENABLE_AVX
+  size_t c8 = ((end - start) / C8NUM) * C8NUM;
+  __m256 coeff1_r = _mm256_set1_ps(1.0f - beta1);
+  __m256 coeff2_r = _mm256_set1_ps(1.0f - beta2);
+  __m256 beta1_r = _mm256_set1_ps(beta1);
+  __m256 beta2_r = _mm256_set1_ps(beta2);
+  __m256 lr_r = _mm256_set1_ps(-lr);
+  __m256 epsi_r = _mm256_set1_ps(epsilon);
+
+  float *m_ptr = m + start;
+  float *v_ptr = v + start;
+  float *delta_ptr = delta + start;
+  const float *gradient_ptr = gradient + start;
+
+  __m256 m_r, v_r, delta_r, grad_r;
+  __m256 avx_r0, avx_r1;
+  for (; c1 < start + c8; c1 += C8NUM) {
+    m_r = _mm256_loadu_ps(m_ptr);
+    avx_r0 = _mm256_mul_ps(m_r, beta1_r);
+    grad_r = _mm256_loadu_ps(gradient_ptr);
+    m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r));
+    _mm256_storeu_ps(m_ptr, m_r);
+
+    v_r = _mm256_loadu_ps(v_ptr);
+    avx_r0 = _mm256_mul_ps(v_r, beta2_r);
+    avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r);
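+    /* second-moment update: v <- beta2 * v + (1 - beta2) * g * g */
+    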
v_r = _mm256_add_ps(avx_r0, avx_r1); + _mm256_storeu_ps(v_ptr, v_r); + + if (use_nesterov) { + avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r)); + avx_r0 = _mm256_mul_ps(lr_r, avx_r0); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + delta_r = _mm256_div_ps(avx_r0, avx_r1); + _mm256_storeu_ps(delta_ptr, delta_r); + } else { + avx_r0 = _mm256_mul_ps(lr_r, m_r); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + delta_r = _mm256_div_ps(avx_r0, avx_r1); + _mm256_storeu_ps(delta_ptr, delta_r); + } + m_ptr += C8NUM; + v_ptr += C8NUM; + delta_ptr += C8NUM; + gradient_ptr += C8NUM; + } +#endif + + // remaining + for (; c1 < end; ++c1) { + m[c1] *= beta1; + m[c1] += (1 - beta1) * gradient[c1]; + v[c1] *= beta2; + v[c1] += (1 - beta2) * gradient[c1] * gradient[c1]; + if (use_nesterov) { + delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon); + } else { + delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon); + } + } + return NNACL_OK; +} + +int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t start, size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end); + + // remaining + const float beta1_minus = 1 - beta1; + const float beta2_minus = 1 - beta2; + for (; c1 < end; c1++) { + m[c1] += (gradient[c1] - m[c1]) * beta1_minus; + v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus; + var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]); + } + return NNACL_OK; +} + +size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + + SIMD_RUN_AVX512(FusedCastAdamFp32Fp16, c1, var, gradient16, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + + SIMD_RUN_AVX512(FusedCastAdamFp32Fp32, c1, var, gradient32, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(FusedCastAdamFp16Fp16, c1, var16, gradient16, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(FusedCastAdamFp16Fp32, c1, var16, gradient32, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power, + float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) { + if ((1.f - beta1_power[0]) <= 0.0f) { + return NNACL_PARAM_INVALID; + } + if ((1.f - beta2_power[0]) < 0.0f) { + return 
NNACL_ERRCODE_SQRT_NEGATIVE; + } + + float update_lr = learning_rate * sqrtf(1.f - beta2_power[0]) / (1.f - beta1_power[0]); + const float one_minus_beta1 = 1.f - beta1; + const float one_minus_beta2 = 1.f - beta2; + if (nesterov) { // Nadam + for (int i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (sqrtf(v[i]) + eps); + } + } else { + for (int i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * m[i] / (sqrtf(v[i]) + eps); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..4f2b0e987027c71ef6aa3beba89a903dcdb1c58f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADAM_FP32_H +#define MINDSPORE_NNACL_ADAM_FP32_H +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient, + size_t start, size_t end, bool use_nesterov); +int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon, + const float *gradient, size_t start, size_t end, bool use_nesterov); +int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t start, size_t end); +size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); + +int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power, + float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end); +#ifdef __cplusplus +} +#endif +#endif // 
MINDSPORE_NNACL_ADAM_FP32_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..806f71bd77d41a9bb7d646518a604ccffcc994e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in @@ -0,0 +1,203 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ADAM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ADAM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ +#ifdef MS_SIMD_AVX512 + static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient + index); + + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_F32(var + index, var_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); 
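+    /* the half-precision gradient is widened to fp32 (SIMD_F16_TO_F32) before the update below */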
+    SIMD_F32 v_r = SIMD_LD_F32(v + index);
+    SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index));
+
+    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
+    m_r = SIMD_MUL_F32(m_r, beta1_r);
+    v_r = SIMD_MUL_F32(v_r, beta2_r);
+    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
+    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
+    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
+    avx_r0 = SIMD_SQRT_F32(v_r);
+    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
+    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
+    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
+    SIMD_ST_F32(var + index, var_r);
+    SIMD_ST_F32(m + index, m_r);
+    SIMD_ST_F32(v + index, v_r);
+  }
+
+  return index;
+}
+
+static inline size_t FusedCastAdamFp32Fp32@SIMD_INSTRUCTION@(size_t index, float *var, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
+                                                             float global_norm_reciprocal, size_t end) {
+  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
+  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
+  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
+  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
+  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
+  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
+  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
+  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);
+
+  for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 var_r = SIMD_LD_F32(var + index);
+    SIMD_F32 m_r = SIMD_LD_F32(m + index);
+    SIMD_F32 v_r = SIMD_LD_F32(v + index);
+    SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
+
+    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
+    m_r = SIMD_MUL_F32(m_r, beta1_r);
+    v_r = SIMD_MUL_F32(v_r, beta2_r);
+    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
+    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
+    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
+    avx_r0 = SIMD_SQRT_F32(v_r);
+    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
+    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
+    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
+    SIMD_ST_F32(var + index, var_r);
+    SIMD_ST_F32(m + index, m_r);
+    SIMD_ST_F32(v + index, v_r);
+  }
+
+  return index;
+}
+
+static inline size_t FusedCastAdamFp16Fp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
+                                                             float global_norm_reciprocal, size_t end) {
+  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
+  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
+  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
+  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
+  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
+  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
+  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
+  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);
+
+  for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index));
+    SIMD_F32 m_r = SIMD_LD_F32(m + index);
+    SIMD_F32 v_r = SIMD_LD_F32(v + index);
+    SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index));
+    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
+    m_r = SIMD_MUL_F32(m_r, beta1_r);
+    v_r = SIMD_MUL_F32(v_r, beta2_r);
+    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
+    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
+    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
+    avx_r0 = SIMD_SQRT_F32(v_r);
+    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
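+    /* AdamW step: var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var) */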
+    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
+    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
+    SIMD_ST_F32(m + index, m_r);
+    SIMD_ST_F32(v + index, v_r);
+    SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0));
+  }
+
+  return index;
+}
+
+static inline size_t FusedCastAdamFp16Fp32@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
+                                                             float global_norm_reciprocal, size_t end) {
+  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
+  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
+  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
+  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
+  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
+  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
+  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
+  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);
+
+  for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index));
+    SIMD_F32 m_r = SIMD_LD_F32(m + index);
+    SIMD_F32 v_r = SIMD_LD_F32(v + index);
+    SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
+    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
+    m_r = SIMD_MUL_F32(m_r, beta1_r);
+    v_r = SIMD_MUL_F32(v_r, beta2_r);
+    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
+    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
+    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
+    avx_r0 = SIMD_SQRT_F32(v_r);
+    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
+    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
+    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
+    SIMD_ST_F32(m + index, m_r);
+    SIMD_ST_F32(v + index, v_r);
+    SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0));
+  }
+
+  return index;
+}
+#endif
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..f3e16f5c417903ed4688bc215cfd9ab3867ba783
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c
@@ -0,0 +1,156 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/add_fp32_simd.h" + +int ElementOptAdd(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAdd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAdd, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptAddExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddExtNum0, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index] * alpha; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddExtNum1, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0] * alpha; + } + } + return NNACL_OK; +} + +int ElementOptAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddInt, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] + in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] + in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] + in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu6, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] + in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementAdd(tile_in0, tile_in1, out, size); +} + +int ElementAdd(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAdd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index]; + } + return NNACL_OK; +} + +int ElementAddExt(const float *in0, const float *in1, const float alpha, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddExt, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index] * alpha; + } + return NNACL_OK; +} + +int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddRelu, index, in0, in1, out, size); + for (; 
index < size; index++) {
+    float res = in0[index] + in1[index];
+    out[index] = res > 0 ? res : 0;
+  }
+  return NNACL_OK;
+}
+
+int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementAddRelu6, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    out[index] = MSMIN(MSMAX(in0[index] + in1[index], 0), 6);
+  }
+  return NNACL_OK;
+}
+
+int ElementAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementAddInt, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    out[index] = in0[index] + in1[index];
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..783bb8a86b8f7c7309a87e99c2a58a84f8c427fd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_ADD_H_
+#define MINDSPORE_NNACL_FP32_ADD_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/arithmetic_base.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ElementAdd(const float *in0, const float *in1, float *out, int size);
+int ElementAddExt(const float *in0, const float *in1, const float alpha, float *out, int size);
+int ElementOptAddExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar);
+int ElementAddRelu(const float *in0, const float *in1, float *out, int size);
+int ElementAddRelu6(const float *in0, const float *in1, float *out, int size);
+int ElementAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size);
+int ElementOptAdd(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar);
+int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
+                 ArithmeticParameter *param);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_ADD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..ea8f846fc3cfb8b714f2d110d6de1e8e13aa1269
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in
@@ -0,0 +1,153 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ADD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ADD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptAdd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0_, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddExtNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + SIMD_F32 vin0 = SIMD_MOV_F32(in0[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddExtNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + SIMD_F32 vin1 = SIMD_MOV_F32(in1[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, + int size) { + SIMD_EPI32 vin0_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_ADD_EPI32(vin0_, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_ADD_F32(vin0_, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_ADD_F32(vin0_, 
vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAdd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddExt@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_ADD_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_ADD_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_ADD_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b5ef4123879efe05918dadc459e81ca3a91aac21 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp32/adder_fp32.h"
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
+
+void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+               int col, int stride) {
+  for (int r = 0; r < row; r++) {
+    for (int c = 0; c < col; c++) {
+      int r12div = r / 12, r12mod = r % 12;
+      int c4div = c / 4, c4mod = c % 4;
+      size_t ci = r * stride + c;
+      float value = 0;
+      for (int d = 0; d < deep; d++) {
+        size_t ai = r12div * deep * 12 + d * 12 + r12mod;
+        size_t bi = c4div * deep * 4 + d * 4 + c4mod;
+        value += fabsf(a[ai] - b[bi]);
+      }
+      value = -value;
+      if (bias != NULL) value += bias[c];
+      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+      if (act_type != ActType_No) value = MSMAX(0.0f, value);
+      dst[ci] = value;
+    }
+  }
+}
+
+void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
+              size_t stride) {
+#ifdef ENABLE_ARM64
+  AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride);
+#else
+  Adder12x4(a, b, c, bias, act_type, deep, row, col, stride);
+#endif
+}
+
+void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
+               float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
+  int out_channel = conv_param->output_channel_;
+  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
+  int output_count = conv_param->output_h_ * conv_param->output_w_;
+  if (conv_param->thread_num_ == 0) {
+    return;
+  }
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+  const int cal_num = C4NUM;
+#else
+  const int cal_num = C12NUM;
+#endif
+  int output_tile_count = UP_DIV(output_count, cal_num);
+
+  for (int b = 0; b < conv_param->input_batch_; b++) {
+    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
+    int out_batch_offset = b * out_channel * output_count;
+    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
+      int start_index = thread_id * cal_num;
+      int real_cal_num = (output_count - start_index) < cal_num ?
(output_count - start_index) : cal_num;
+      float *gemm_input = packed_input + task_id * deep * cal_num;
+      float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
+      size_t packed_input_size = deep * cal_num * sizeof(float);
+      memset(gemm_input, 0, packed_input_size);
+      memset(col_major_gemm_input, 0, packed_input_size);
+      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
+
+      int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
+      float *gemm_output = output_data + out_offset;
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+      RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#else
+      RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#endif
+      AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num,
+               out_channel, out_channel);
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee59ec669af7c319afd6c3c5020d0a1debeb6fff
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_ADDER_H_
+#define MINDSPORE_NNACL_FP32_ADDER_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef ENABLE_ARM64
+void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                      int col, size_t stride);
+#endif
+
+void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
+              size_t stride);
+
+void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
+               float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_ADDER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..cb8f49d793f3220f953b869d2f38f23862c7953e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c
@@ -0,0 +1,298 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/arg_min_max_fp32.h" +#include + +#define ARG_MIN_MAX_FUNC(data_type) \ + int ArgCompareDesc32##data_type(const void *a, const void *b) { \ + DATA_TYPE b_value = ((ArgElement *)b)->data_.UNION_DATA; \ + DATA_TYPE a_value = ((ArgElement *)a)->data_.UNION_DATA; \ + if (b_value > a_value) { \ + return 1; \ + } \ + if (b_value < a_value) { \ + return -1; \ + } \ + return 0; \ + } \ + int ArgCompareAsc32##data_type(const void *a, const void *b) { \ + DATA_TYPE a_value = ((ArgElement *)a)->data_.UNION_DATA; \ + DATA_TYPE b_value = ((ArgElement *)b)->data_.UNION_DATA; \ + if (b_value > a_value) { \ + return -1; \ + } \ + if (b_value < a_value) { \ + return 1; \ + } \ + return 0; \ + } \ + \ + void ArgMaxTopK1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const ArgMinMaxComputeParam *param, int pre_axis_count, int axis_count, \ + int after_axis_count) { \ + bool out_value = param->out_value_; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < pre_axis_count; ++i) { \ + int output_offset = i * after_axis_count; \ + int input_offset = output_offset * axis_count; \ + for (int j = 0; j < after_axis_count; ++j) { \ + DATA_TYPE value = MIN_VALUE; \ + int index = 0; \ + for (int k = 0; k < axis_count; ++k) { \ + DATA_TYPE value_tmp = input[input_offset + k * after_axis_count + j]; \ + if (value_tmp > value) { \ + value = value_tmp; \ + index = k; \ + } \ + } \ + if (out_value) { \ + outputfp32[output_offset + j] = value; \ + } else { \ + outputint[output_offset + j] = index; \ + } \ + if (output_value != NULL) { \ + output_value[output_offset + j] = value; \ + } \ + } \ + } \ + } \ + \ + void ArgMinTopK1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const ArgMinMaxComputeParam *param, int pre_axis_count, int axis_count, \ + int after_axis_count) { \ + bool out_value = param->out_value_; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < pre_axis_count; ++i) { \ + int output_offset = i * after_axis_count; \ + int input_offset = output_offset * axis_count; \ + for (int j = 0; j < after_axis_count; ++j) { \ + DATA_TYPE value = MAX_VALUE; \ + int index = 0; \ + for (int k = 0; k < axis_count; ++k) { \ + DATA_TYPE value_tmp = input[input_offset + k * after_axis_count + j]; \ + if (value_tmp < value) { \ + value = value_tmp; \ + index = k; \ + } \ + } \ + if (out_value) { \ + outputfp32[output_offset + j] = value; \ + } else { \ + outputint[output_offset + j] = index; \ + } \ + if (output_value != NULL) { \ + output_value[output_offset + j] = value; \ + } \ + } \ + } \ + } \ + \ + void ArgMinMaxDim0##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { \ + for (int j = 0; j < in_shape[0]; ++j) { \ + int 
offset = param->in_strides_[0] * j + i; \ + param->arg_elements_[j].index_ = (uint32_t)j; \ + param->arg_elements_[j].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func); \ + for (int j = 0; j < param->topk_; ++j) { \ + int out_offset = j * param->out_strides_[0] + i; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[j].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[j].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[j].data_.UNION_DATA; \ + } \ + } \ + } \ + return; \ + } \ + \ + void ArgMinMaxDim1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + int in_shape1 = in_shape[1]; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < param->in_strides_[1]; ++j) { \ + for (int k = 0; k < in_shape1; ++k) { \ + int offset = param->in_strides_[1] * k + in_dim0_offset + j; \ + param->arg_elements_[k].index_ = (uint32_t)k; \ + param->arg_elements_[k].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func); \ + for (int k = 0; k < param->topk_; ++k) { \ + int out_offset = out_dim0_offset + j + k * param->out_strides_[1]; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[k].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[k].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[k].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + return; \ + } \ + \ + void ArgMinMaxDim2##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + int in_shape1 = in_shape[1]; \ + int in_shape2 = in_shape[2]; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < in_shape1; ++j) { \ + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; \ + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; \ + for (int k = 0; k < param->in_strides_[2]; ++k) { \ + for (int l = 0; l < in_shape2; ++l) { \ + int offset = param->in_strides_[2] * l + k + in_dim1_offset; \ + param->arg_elements_[l].index_ = (uint32_t)l; \ + param->arg_elements_[l].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func); \ + for (int l = 0; l < param->topk_; ++l) { \ + int out_offset = out_dim1_offset + k + l * param->out_strides_[2]; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[l].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + } \ + } \ + \ + void ArgMinMaxDim3##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, 
const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + int in_shape1 = in_shape[1]; \ + int in_shape2 = in_shape[2]; \ + int in_shape3 = in_shape[3]; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < in_shape1; ++j) { \ + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; \ + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; \ + for (int k = 0; k < in_shape2; ++k) { \ + int in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; \ + int out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; \ + for (int l = 0; l < in_shape3; ++l) { \ + int offset = l + in_dim2_offset; \ + param->arg_elements_[l].index_ = (uint32_t)l; \ + param->arg_elements_[l].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func); \ + for (int l = 0; l < param->topk_; ++l) { \ + int out_offset = out_dim2_offset + l; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = (int)(param->arg_elements_[l].index_); \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + } \ + } \ + \ + void ArgMinMax##data_type##32(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param) { \ + if (param->topk_ == 1) { \ + int pre_axis_count = 1; \ + int axis_count = 1; \ + int after_axis_count = 1; \ + ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); \ + \ + if (param->get_max_) { \ + ArgMaxTopK1##data_type(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); \ + } else { \ + ArgMinTopK1##data_type(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); \ + } \ + return; \ + } \ + \ + COMPARE_FUNCTION compare_function = NULL; \ + if (param->get_max_) { \ + compare_function = ArgCompareDesc32##data_type; \ + } else { \ + compare_function = ArgCompareAsc32##data_type; \ + } \ + \ + switch (param->axis_) { \ + case 0: \ + ArgMinMaxDim0##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 1: \ + ArgMinMaxDim1##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 2: \ + ArgMinMaxDim2##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 3: \ + ArgMinMaxDim3##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + } \ + return; \ + } + +#define DATA_TYPE float +#define MIN_VALUE -FLT_MAX +#define MAX_VALUE FLT_MAX +#define UNION_DATA f_data_ +ARG_MIN_MAX_FUNC(Fp) +#undef DATA_TYPE +#undef MIN_VALUE +#undef MAX_VALUE +#undef UNION_DATA + +#define DATA_TYPE int32_t +#define MIN_VALUE INT32_MIN +#define MAX_VALUE INT32_MAX +#define UNION_DATA i_data_ +ARG_MIN_MAX_FUNC(Int) +#undef DATA_TYPE +#undef MIN_VALUE +#undef MAX_VALUE +#undef UNION_DATA diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..8895eeab9c85a92306727b64e423063a18aa913c 
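For orientation: ARG_MIN_MAX_FUNC above stamps out a complete argmin/argmax kernel per element type through token pasting, so ARG_MIN_MAX_FUNC(Fp) with DATA_TYPE float yields ArgMinMaxFp32, and ARG_MIN_MAX_FUNC(Int) yields ArgMinMaxInt32. A minimal standalone sketch of the topk == 1 argmax path, using the same pre-axis/axis/after-axis flattening that ComputeAxisDims feeds into ArgMaxTopK1Fp (every name below is illustrative, not part of the patch):

#include <stdint.h>
#include <stdio.h>

/* Argmax over the middle dimension of a [pre][axis][after] view of the
 * tensor; mirrors ArgMaxTopK1Fp with out_value_ == false. */
static void argmax_axis(const float *in, int32_t *out_index, int pre, int axis, int after) {
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < after; ++j) {
      float best = in[i * axis * after + j];
      int32_t best_k = 0;
      for (int k = 1; k < axis; ++k) {
        float v = in[(i * axis + k) * after + j];
        if (v > best) {
          best = v;
          best_k = k;
        }
      }
      out_index[i * after + j] = best_k;
    }
  }
}

int main(void) {
  const float data[] = {0.1f, 2.0f, 0.5f, 7.0f, -1.0f, 3.0f}; /* shape [2, 3] */
  int32_t idx[2];
  argmax_axis(data, idx, 2, 3, 1); /* reduce over axis 1 */
  printf("%d %d\n", (int)idx[0], (int)idx[1]); /* prints: 1 0 */
  return 0;
}

For a [2, 3] input reduced over axis 1, ComputeAxisDims would produce pre_axis_count = 2, axis_count = 3, after_axis_count = 1, and the kernel writes one index (or value, when out_value_ is set) per surviving position.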
--- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FP32_ARG_MIN_MAX_FP32_H_ +#define FP32_ARG_MIN_MAX_FP32_H_ + +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ArgMinMaxFp32(const float *input, void *output, float *output_value, const int32_t *in_shape, + const ArgMinMaxComputeParam *param); +void ArgMinMaxInt32(const int32_t *input, void *output, int32_t *output_value, const int32_t *in_shape, + const ArgMinMaxComputeParam *param); +#ifdef __cplusplus +} +#endif + +#endif // FP32_ARG_MIN_MAX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..70d0a55a9b72bc4f8334370b23e9482b3cc4f731 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/arithmetic_compare_fp32.h" + +inline bool EqualFp32(float x, float y); +inline bool EqualBool(bool x, bool y); +inline bool NotEqualFp32(float x, float y); +inline bool LessFp32(float x, float y); +inline bool LessEqualFp32(float x, float y); +inline bool GreaterFp32(float x, float y); +inline bool GreaterEqualFp32(float x, float y); + +inline bool EqualInt32(int x, int y); +inline bool NotEqualInt32(int x, int y); +inline bool NotEqualInt64(int64_t x, int64_t y); +inline bool LessInt32(int x, int y); +inline bool LessEqualInt32(int x, int y); +inline bool GreaterInt32(int x, int y); +inline bool GreaterEqualInt32(int x, int y); + +bool EqualFp32(float x, float y) { return x == y; } +bool EqualBool(bool x, bool y) { return x == y; } +bool NotEqualFp32(float x, float y) { return x != y; } +bool LessFp32(float x, float y) { return x < y; } +bool LessEqualFp32(float x, float y) { return x <= y; } +bool GreaterFp32(float x, float y) { return x > y; } +bool GreaterEqualFp32(float x, float y) { return x >= y; } + +bool EqualInt32(int x, int y) { return x == y; } +bool NotEqualInt32(int x, int y) { return x != y; } +bool NotEqualInt64(int64_t x, int64_t y) { return x != y; } +bool LessInt32(int x, int y) { return x < y; } +bool LessEqualInt32(int x, int y) { return x <= y; } +bool GreaterInt32(int x, int y) { return x > y; } +bool GreaterEqualInt32(int x, int y) { return x >= y; } + +#define ELEMENT_COMPARE(input0, input1, output, element_size, compare_func) \ + do { \ + for (int i = 0; i < element_size; i++) { \ + output[i] = compare_func(input0[i], input1[i]); \ + } \ + return NNACL_OK; \ + } while (0) + +#define ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, compare_func) \ + do { \ + int i = 0; \ + if (first_scalar) { \ + for (; i < element_size; i++) { \ + output[i] = compare_func(input0[0], input1[i]); \ + } \ + } else { \ + for (; i < element_size; i++) { \ + output[i] = compare_func(input0[i], input1[0]); \ + } \ + } \ + return NNACL_OK; \ + } while (0) + +// equal: +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualFp32); +} + +int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualBool); +} + +int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualFp32); +} + +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualInt32); +} + +int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualInt32); +} + +// not equal +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualFp32); +} + +int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualFp32); +} + +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + 
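+  // The ELEMENT_COMPARE macro below expands, roughly, to:
+  //   for (int i = 0; i < element_size; i++) {
+  //     output[i] = NotEqualInt32(input0[i], input1[i]);
+  //   }
+  //   return NNACL_OK;
+  // Every comparison kernel is the same elementwise loop with a different
+  // predicate; the Opt variants broadcast whichever operand is scalar.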
ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt32); +} + +int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt32); +} + +int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt64); +} + +int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt64); +} + +// less +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessFp32); +} + +int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessFp32); +} + +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessInt32); +} + +int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessInt32); +} + +// less equal +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualFp32); +} + +int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualFp32); +} + +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualInt32); +} + +int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualInt32); +} + +// greater +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterFp32); +} + +int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterFp32); +} + +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterInt32); +} + +int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterInt32); +} + +// greater equal +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualFp32); +} + +int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, 
element_size, first_scalar, GreaterEqualFp32); +} + +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualInt32); +} + +int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualInt32); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..5ffc7a48e9125958fe685757b6e61dbe5f9827c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ +#define MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size); +int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size); +int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementLessEqualFp32(const float *input0, const 
float *input1, uint8_t *output, int element_size); +int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7cba904f9e485cb774ab07597bb64d06b10a14c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c @@ -0,0 +1,482 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
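All of these comparison kernels share one calling convention: the output is a uint8_t mask holding 0 or 1 per element, and the Opt variants broadcast whichever operand is a scalar (first_scalar selects which). A small usage sketch, with made-up values:

const float a[] = {1.0f, 2.0f, 3.0f};
const float b[] = {2.0f, 2.0f, 2.0f};
uint8_t mask[3];
ElementLessFp32(a, b, mask, 3);           /* elementwise: mask = {1, 0, 0} */
ElementOptLessFp32(a, b, mask, 3, false); /* b[0] broadcast: a[i] < 2.0f -> {1, 0, 0} */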
+ */ + +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include <math.h> +#include "nnacl_c/arithmetic_fp32_simd.h" + +#define ACCURACY_DATA 0.00000001 + +int ElementFloorMod(const float *in0, const float *in1, float *out, int size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloorMod, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = in0[i] - floorf(in0[i] / in1[i]) * in1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int i = 0; + + if (first_scalar) { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorModNum0, i, in0, in1, out, size); // neon no floor instruction + for (; i < size; i++) { + out[i] = in0[0] - floorf(in0[0] / in1[i]) * in1[i]; + } + } else { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorModNum1, i, in0, in1, out, size); // neon no floor instruction + for (; i < size; i++) { + out[i] = in0[i] - floorf(in0[i] / in1[0]) * in1[0]; + } + } + + return NNACL_OK; +} + +int ElementFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int i = 0; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + int remainder = in0[i] - (in0[i] / in1[i]) * in1[i]; + out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder; + } + return NNACL_OK; +} + +int ElementOptFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int i = 0; + if (first_scalar) { + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + int remainder = in0[0] - (in0[0] / in1[i]) * in1[i]; + out[i] = (remainder != 0) && ((in0[0] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + for (; i < size; i++) { + int remainder = in0[i] - (in0[i] / in1[0]) * in1[0]; + out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[0] > 0)) ? 
remainder + in1[0] : remainder; + } + } + + return NNACL_OK; +} + +int ElementMod(const float *in0, const float *in1, float *out, int size) { + for (int i = 0; i < size; i++) { + out[i] = fmodf(in0[i], in1[i]); + } + return NNACL_OK; +} + +int ElementOptMod(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = fmodf(in0[0], in1[index]); + } + } else { + for (; index < size; index++) { + out[index] = fmodf(in0[index], in1[0]); + } + } + return NNACL_OK; +} + +int ElementModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int i = 0; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[i] % in1[i]; + } + return NNACL_OK; +} + +int ElementOptModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + if (first_scalar) { + for (int index = 0; index < size; index++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[index]); + out[index] = in0[0] % in1[index]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + for (int index = 0; index < size; index++) { + out[index] = in0[index] % in1[0]; + } + } + return NNACL_OK; +} + +int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloorDiv, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[i] / in1[i]); + } + return NNACL_OK; +} + +int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int i = 0; + + if (first_scalar) { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorDivNum0, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[0] / in1[i]); + } + } else { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorDivNum1, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[i] / in1[0]); + } + } + + return NNACL_OK; +} + +int ElementFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementFloorDivInt, i, in0, in1, out, size); + + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int i = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptFloorDivIntNum0, i, in0, in1, out, size); + + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[0] / in1[i]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + + SIMD_RUN_NO_SCALAR(ElementOptFloorDivIntNum1, i, in0, in1, out, size); + + for (; i < size; i++) { + out[i] = in0[i] / in1[0]; + } + } + + return NNACL_OK; +} + +int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementLogicalAnd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) & (bool)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + SIMD_RUN_NO_SCALAR(ElementOptLogicalAnd, index, in0, in1, out, size, first_scalar); + if (first_scalar) { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[0]) & (bool)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = 
(float)((bool)(in0[index]) & (bool)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[0]) & (unsigned int)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); + } + + return NNACL_OK; +} + +int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[0]) & (unsigned int)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) { + int index = 0; +#ifdef ENABLE_NEON + float32x4_t vtrue = vdupq_n_f32(1); + float32x4_t vfalse = vdupq_n_f32(0); + uint32x4_t mask = vmovq_n_u32(((uint32_t)(1u << 31) - 1)); + uint32x4_t zeros = vdupq_n_u32(0); + for (; index <= size - 4; index += C4NUM) { + uint32x4_t vin0 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in0 + index)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in1 + index)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f32(out + index, vout); + } +#endif + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) | (bool)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[0]) | (bool)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) | (bool)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (bool)(in0[index] | in1[index]); + } + return NNACL_OK; +} + +int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (bool)(in0[0] | in1[index]); + } + } else { + for (; index < size; index++) { + out[index] = (bool)(in0[index] | in1[0]); + } + } + + return NNACL_OK; +} + +int ElementMaximum(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMaximum, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? 
in0[index] : in1[index]; + } + return NNACL_OK; +} + +int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMaximumNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in0[0] : in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMaximumNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in0[index] : in1[0]; + } + } + + return NNACL_OK; +} + +int ElementMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMaximumInt, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in0[index] : in1[index]; + } + return NNACL_OK; +} + +int ElementOptMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMaximumIntNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in0[0] : in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMaximumIntNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in0[index] : in1[0]; + } + } + + return NNACL_OK; +} + +int ElementMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMinimumInt, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; + } + return NNACL_OK; +} + +int ElementOptMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMinimumIntNum0, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[0] > input1[index] ? input1[index] : input0[0]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMinimumIntNum1, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[index] > input1[0] ? input1[0] : input0[index]; + } + } + + return NNACL_OK; +} + +int ElementMinimum(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMinimum, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in1[index] : in0[index]; + } + return NNACL_OK; +} + +int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMinimumNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in1[index] : in0[0]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMinimumNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? 
in1[0] : in0[index]; + } + } + + return NNACL_OK; +} + +#undef ACCURACY_DATA + +void TileOneDimensionFp32(const void *inPtr, void *outPtr, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple) { + const float *inData = (const float *)inPtr; + float *outData = (float *)outPtr; + + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionFp32(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data0, float *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionFp32(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionFp32(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void AssignSubOpt(float *in0, const float *in1, size_t size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(AssignSubOpt, index, in0, in1, size); + + for (; index < size; index++) { + in0[index] = in0[index] - in1[index]; + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..005c90a7ef359e1edfae4d177460ff788fbe1c0b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
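Before the header, a quick illustration of the broadcast-tiling scheme TileOneDimensionFp32 implements above: the innermost dimension is bulk-copied with memcpy, and every outer dimension recurses once per requested multiple. A hedged usage sketch with a 1-D input (values illustrative):

const int32_t in_shape[] = {2};
const int32_t in_strides[] = {1};
const int32_t out_strides[] = {1};
const int32_t multiple[] = {3};
float in[] = {1.0f, 2.0f};
float out[6];
TileOneDimensionFp32(in, out, 0, 1, in_shape, in_strides, out_strides, multiple);
/* out is now {1, 2, 1, 2, 1, 2} */

TileDimensionsFp32 simply runs this on both operands after CalcMultiplesAndStrides has filled in the stride and multiple arrays, which is how the arithmetic kernels align two broadcastable shapes.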
+ */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_H_ +#define MINDSPORE_NNACL_ARITHMETIC_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/fp32/mul_fp32.h" +#include "nnacl_c/fp32/div_fp32.h" +#include "nnacl_c/fp32/sub_fp32.h" +#include "nnacl_c/fp32/squared_difference.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TileOneDimensionFp32(const void *inData, void *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple); +void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data0, float *tile_data1, + ArithmeticParameter *param); +/* logical and */ +int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size); +int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size); +int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar); + +/* logical or */ +int ElementLogicalOr(const float *in0, const float *in1, float *out, int size); +int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size); +int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar); + +/* max min */ +int ElementMaximum(const float *in0, const float *in1, float *out, int size); +int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementMinimum(const float *in0, const float *in1, float *out, int size); +int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size); +int ElementOptMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size, bool first_scalar); + +/* floor div */ +int ElementFloorDiv(const float *in0, const float *in1, float *out, int size); +int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +/* floor mod */ +int ElementFloorMod(const float *in0, const float *in1, float *out, int size); +int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +/* mod */ +int ElementMod(const float *in0, const float *in1, float *out, int size); +int ElementOptMod(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int 
ElementModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +void AssignSubOpt(float *in0, const float *in1, size_t size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..24688952a181f2a07cf4ad1dcb1b2905d8be0ae7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in @@ -0,0 +1,287 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_ARITHMETIC_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_ARITHMETIC_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +#ifndef MS_SIMD_NEON +static inline int ElementFloorMod@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorModNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorModNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementFloorDiv@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + 
index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, floor_tmp); + } + return index; +} + +static inline int ElementOptFloorDivNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} +#endif + +static inline int ElementFloorDivInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMaximum@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < 
block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMaximumInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMinimumInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMinimum@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - 
BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline size_t AssignSubOpt@SIMD_INSTRUCTION@(int index, float *in0, const float *in1, size_t size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(in0 + index, out_tmp); + } + return index; +} + +int ElementLogicalAnd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +int ElementOptLogicalAnd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size, bool first_scalar) { + if (first_scalar) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(*in0); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + } else { + SIMD_F32 in1_tmp = SIMD_MOV_F32(*in1); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a85950c59b57c75daeea96eb78f832f522323c71 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c @@ -0,0 +1,230 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <math.h> +#include <stdlib.h> +#include <float.h> +#include "nnacl_c/fp32/arithmetic_self_fp32.h" +#include "nnacl_c/arithmetic_self_fp32_simd.h" + +int ElementAbs(const float *input, float *output, const int element_size) { + int i = 0; + + // only avx512 support abs fp32 instruction + SIMD_RUN_AVX512(ElementAbs, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return NNACL_OK; +} + +int ElementAbsInt(const int32_t *input, int32_t *output, const int element_size) { + int i = 0; + + // only avx512 support abs int32 instruction + SIMD_RUN_AVX512(ElementAbsInt, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = abs(input[i]); + } + return NNACL_OK; +} + +// cos +int ElementCos(const float *input, float *output, const int element_size) { + int i = 0; + SIMD_RUN_X86_NO_SCALAR(ElementCos, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return NNACL_OK; +} + +// log: +int ElementLog(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementLog, i, input, output, element_size); + for (; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < 0)) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return NNACL_OK; +} + +// log1p: +int ElementLog1p(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < -1.0f)) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = log1p(input[i]); + } + return NNACL_OK; +} + +int ElementSquare(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementSquare, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return NNACL_OK; +} + +int ElementSqrt(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementSqrt, i, input, output, element_size); + for (; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < 0)) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementRsqrt(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementRsqrt, i, input, output, element_size); + for (; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < 0)) { + return NNACL_ERRCODE_RSQRT_NEGATIVE; + } + output[i] = 1.f / sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementSin(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return NNACL_OK; +} + +// logical_not: +int ElementLogicalNot(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return NNACL_OK; +} + +// logical_not: +int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = !input[i]; + } + return NNACL_OK; +} + +int ElementRound(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_AVX(ElementRound, i, input, output, element_size); + SIMD_RUN_SSE(ElementRound, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = roundf(input[i]); + } + return NNACL_OK; +} + +int ElementFloor(const float *input, float *output, const int element_size) { + int i = 0; + + 
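+  // The SIMD_RUN_* macros (generated from the *_simd.h.in templates) run the widest
+  // vector kernel available at build time over the bulk of the data and return the
+  // updated index, so the scalar loop below only handles the leftover tail that
+  // does not fill a whole SIMD block.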
SIMD_RUN_X86_NO_SCALAR(ElementFloor, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeil(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementCeil, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = ceilf(input[i]); + } + return NNACL_OK; +} + +int ElementNegative(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementNegative, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementNegativeInt(const int32_t *input, int32_t *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementNegativeInt, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementReciprocal(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementReciprocal, i, input, output, element_size); + for (; i < element_size; ++i) { + if (input[i] == 0.0f) { + return NNACL_ERR; + } + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} + +// Erf +int ElementErf(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = erff(input[i]); + } + return NNACL_OK; +} + +int ElementIsFinite(const float *input, bool *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = true; + if (isnan(input[i]) || isinf(input[i])) { + output[i] = false; + } + } + return NNACL_OK; +} + +int ElementMish(const float *input, float *output, const int element_size) { + int i = 0; + SIMD_RUN_NO_SCALAR(ElementMish, i, input, output, element_size); + + for (; i < element_size; ++i) { + simd_exp32(input[i], output + i); + float exp_pow = (output[i] + 1) * (output[i] + 1); + output[i] = input[i] * (exp_pow - 1) / (exp_pow + 1); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..5c4213c56a8706dcb393bdb531ca518fb528c49b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_H_ +#define MINDSPORE_NNACL_ARITHMETIC_SELF_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementAbs(const float *input, float *output, const int element_size); +int ElementAbsInt(const int32_t *input, int32_t *output, const int element_size); + +int ElementCos(const float *input, float *output, const int element_size); + +int ElementLog(const float *input, float *output, const int element_size); + +int ElementLog1p(const float *input, float *output, const int element_size); + +int ElementSquare(const float *input, float *output, const int element_size); + +int ElementSqrt(const float *input, float *output, const int element_size); + +int ElementRsqrt(const float *input, float *output, const int element_size); + +int ElementSin(const float *input, float *output, const int element_size); + +int ElementLogicalNot(const float *input, float *output, const int element_size); + +int ElementLogicalNotBool(const bool *input, bool *output, const int element_size); + +int ElementRound(const float *input, float *output, const int element_size); + +int ElementFloor(const float *input, float *output, const int element_size); + +int ElementCeil(const float *input, float *output, const int element_size); + +int ElementNegative(const float *input, float *output, const int element_size); +int ElementNegativeInt(const int32_t *input, int32_t *output, const int element_size); + +int ElementReciprocal(const float *input, float *output, const int element_size); + +int ElementErf(const float *input, float *output, const int element_size); + +int ElementIsFinite(const float *input, bool *output, const int element_size); + +int ElementMish(const float *input, float *output, const int element_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_SELF_H_
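Note on the *_simd.h.in template that follows: the build instantiates it once per instruction set, substituting @SIMD_INSTRUCTION@ with the ISA name and defining BLOCK_NUM as that ISA's fp32 lane count, while SIMD_LD_F32, SIMD_MUL_F32 and friends resolve to the matching intrinsics. As a rough sketch of what one instantiation looks like (the ElementSquareSSE name and the 4-lane width are inferred from the substitution pattern, not copied from a generated header):

#include <xmmintrin.h> /* SSE intrinsics */

/* Hypothetical SSE expansion of the ElementSquare template below. */
static inline int ElementSquareSSE(int index, const float *input, float *output, const int element_size) {
  /* BLOCK_NUM == 4 on SSE: square four floats per step */
  for (int block_max_size = element_size - 4 + 1; index < block_max_size; index += 4) {
    __m128 vin = _mm_loadu_ps(input + index);
    _mm_storeu_ps(output + index, _mm_mul_ps(vin, vin));
  }
  return index; /* the caller's scalar loop finishes the remaining elements */
}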
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..ec29d20f332ade7e9119778151839d1190129065 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in @@ -0,0 +1,152 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +#if defined(MS_SIMD_AVX512) +// only the AVX512 variant implements abs +static inline int ElementAbs@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_ABS_F32(SIMD_LD_F32(input + index))); + } + return index; +} + +static inline int ElementAbsInt@SIMD_INSTRUCTION@(int index, const int32_t *input, int32_t *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_ABS_EPI32(SIMD_LD_EPI32(input + index))); + } + return index; +} +#endif + +#if !defined(MS_SIMD_NEON) +// not supported on NEON +static inline int ElementCos@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin = SIMD_LD_F32(input + index); + SIMD_ST_F32(output + index, SIMD_COS_F32(vin)); + } + return index; +} + +static inline int ElementLog@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin = SIMD_LD_F32(input + index); + SIMD_ST_F32(output + index, SIMD_LOG_F32(vin)); + } + return index; +} +#endif + +static inline int ElementSquare@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin = SIMD_LD_F32(input + index); + SIMD_ST_F32(output + index, SIMD_MUL_F32(vin, vin)); + } + return index; +} + +static inline int ElementSqrt@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_SQRT_F32(SIMD_LD_F32(input + index))); + } + return index; +} + +static inline int ElementRsqrt@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_RSQRT_F32(SIMD_LD_F32(input + index))); + } + return index; +} + +static inline int ElementMish@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + SIMD_F32 one = SIMD_MOV_F32(1.0f); + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 exp_add_one = SIMD_ADD_F32(SIMD_EXP_F32(SIMD_LD_F32(input + index)), one); + SIMD_F32 exp_pow = SIMD_MUL_F32(exp_add_one, exp_add_one); + SIMD_ST_F32(output + index, SIMD_MUL_F32(SIMD_LD_F32(input + index), + SIMD_DIV_F32(SIMD_SUB_F32(exp_pow, one), SIMD_ADD_F32(exp_pow, one)))); + } + return index; +}
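+ +// The #if guards in this template opt individual kernels out of ISAs that lack the required wrapper: abs is AVX512-only, cos and log are skipped on NEON, round below is AVX/SSE-only, and floor/ceil are skipped on NEON.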
+ +#if defined(MS_SIMD_AVX) || defined(MS_SIMD_SSE) +// AVX512 does not provide the fp32 round instruction used here, so only the AVX/SSE variants are generated +static inline int ElementRound@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_ROUND_F32(SIMD_LD_F32(input + index))); + } + return index; +} +#endif + +#ifndef MS_SIMD_NEON +// NEON does not provide the fp32 floor instruction used here +static inline int ElementFloor@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_FLOOR_F32(SIMD_LD_F32(input + index))); + } + return index; +} +#endif + +#ifndef MS_SIMD_NEON +static inline int ElementCeil@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_CEIL_F32(SIMD_LD_F32(input + index))); + } + return index; +} +#endif + +static inline int ElementNegative@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MUL_N_F32(SIMD_LD_F32(input + index), -1.0f)); + } + return index; +} + +static inline int ElementNegativeInt@SIMD_INSTRUCTION@(int index, const int32_t *input, int32_t *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MUL_N_EPI32(SIMD_LD_EPI32(input + index), -1)); + } + return index; +} + +static inline int ElementReciprocal@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + SIMD_F32 num1 = SIMD_MOV_F32(1.0f); + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_DIV_F32(num1, SIMD_LD_F32(input + index))); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif
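Usage note for the unary kernels above: callers hand in raw float buffers and check the NNACL return code; the SIMD body and the scalar tail produce identical results. A minimal driver sketch (illustrative only, assuming this patch's headers are on the include path; it is not part of the patch):

#include <stdio.h>
#include "nnacl_c/fp32/arithmetic_self_fp32.h"

int main(void) {
  float in[5] = {0.25f, 1.0f, 2.0f, 4.0f, 9.0f};
  float out[5] = {0};
  /* ElementSqrt validates its input and reports NNACL_ERRCODE_SQRT_NEGATIVE on bad data */
  if (ElementSqrt(in, out, 5) != NNACL_OK) {
    return 1;
  }
  for (int i = 0; i < 5; ++i) {
    printf("%f ", out[i]);
  }
  return 0;
}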
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e0bf8f2f81cbaa152257bfac8aa0efaaf8db7f74 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.c @@ -0,0 +1,581 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/attention_fp32.h" +#include <string.h> +#include <math.h> +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/transpose_parameter.h" +#include "nnacl_c/fp32/softmax_fp32.h" +#include "nnacl_c/errorcode.h" + +int InitMatrix(Matrix *matrix, int batch, int row, int col, bool is_trans) { + if (matrix == NULL) { + return NNACL_NULL_PTR; + } + matrix->batch_ = batch; + matrix->row_ = row; + matrix->col_ = col; + matrix->is_transpose_ = is_trans; + matrix->data_ = NULL; + matrix->packed_data_ = NULL; + return NNACL_OK; +} + +size_t LeftMatrixPackElementSize(Matrix *matrix, int row_tile) { + if (matrix == NULL) { + return 0; + } + int real_row = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int deep = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = real_row == 1; + int row_align = vec_matmul ? 1 : UP_ROUND(real_row, row_tile); + int dst_area = row_align * deep; + matrix->packed_row_ = row_align; + matrix->packed_col_ = deep; + return matrix->batch_ * dst_area; +} + +size_t RightMatrixPackElementSize(Matrix *matrix, int col_tile) { + if (matrix == NULL) { + return 0; + } + int deep = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int real_col = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = deep == 1; + int col_align = vec_matmul ? real_col : UP_ROUND(real_col, col_tile); + int dst_area = deep * col_align; + matrix->packed_row_ = deep; + matrix->packed_col_ = col_align; + return matrix->batch_ * dst_area; +} + +int PackLeftMatrix(Matrix *matrix, int row_tile) { + if (matrix == NULL || matrix->data_ == NULL || row_tile == 0) { + return NNACL_NULL_PTR; + } + int real_row = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int deep = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = real_row == 1; + int row_align = vec_matmul ?
1 : UP_ROUND(real_row, row_tile); + int src_area = matrix->row_ * matrix->col_; + int dst_area = row_align * deep; + bool malloced = false; + if (matrix->packed_data_ == NULL) { + matrix->packed_data_ = (float *)malloc(dst_area * matrix->batch_ * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + malloced = true; + } + + if (vec_matmul) { + memcpy(matrix->packed_data_, matrix->data_, matrix->batch_ * dst_area * sizeof(float)); + } else { + for (int i = 0; i < matrix->batch_; i++) { + const float *cur_src = matrix->data_ + i * src_area; + float *cur_dst = matrix->packed_data_ + i * dst_area; + switch (row_tile) { + case C6NUM: + if (matrix->is_transpose_) { + RowMajor2Row6Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col6Major(cur_src, cur_dst, real_row, deep); + } + break; + case C4NUM: + if (matrix->is_transpose_) { + RowMajor2Row4Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col4Major(cur_src, cur_dst, real_row, deep); + } + break; + case C12NUM: + if (matrix->is_transpose_) { + RowMajor2Row12Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col12Major(cur_src, cur_dst, real_row, deep); + } + break; + default: + if (malloced) { + free(matrix->packed_data_); + matrix->packed_data_ = NULL; + return NNACL_ERR; + } + break; + } + } + } + matrix->packed_row_ = row_align; + matrix->packed_col_ = deep; + return NNACL_OK; +} + +int PackRightMatrix(Matrix *matrix, int col_tile) { + if (matrix == NULL || matrix->data_ == NULL || col_tile == 0) { + return NNACL_NULL_PTR; + } + int deep = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int real_col = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = deep == 1; + int col_align = vec_matmul ? real_col : UP_ROUND(real_col, col_tile); + int src_area = matrix->row_ * matrix->col_; + int dst_area = deep * col_align; + bool malloced = false; + if (matrix->packed_data_ == NULL) { + matrix->packed_data_ = (float *)malloc(dst_area * matrix->batch_ * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + malloced = true; + } + if (vec_matmul) { + memcpy(matrix->packed_data_, matrix->data_, matrix->batch_ * dst_area * sizeof(float)); + } else { + for (int i = 0; i < matrix->batch_; i++) { + const float *cur_src = matrix->data_ + i * src_area; + float *cur_dst = matrix->packed_data_ + i * dst_area; + switch (col_tile) { + case C16NUM: + if (matrix->is_transpose_) { + RowMajor2Col16Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row16Major(cur_src, cur_dst, deep, real_col); + } + break; + case C4NUM: + if (matrix->is_transpose_) { + RowMajor2Col4Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row4Major(cur_src, cur_dst, deep, real_col); + } + break; + case C8NUM: + if (matrix->is_transpose_) { + RowMajor2Col8Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row8Major(cur_src, cur_dst, deep, real_col); + } + break; + default: + if (malloced) { + free(matrix->packed_data_); + matrix->packed_data_ = NULL; + return NNACL_ERR; + } + break; + } + } + } + matrix->packed_row_ = deep; + matrix->packed_col_ = col_align; + return NNACL_OK; +} + +int PackAttentionBias(Matrix *matrix, int tile) { + if (matrix == NULL || matrix->batch_ != 1 || matrix->row_ != 1 || matrix->data_ == NULL) { + return NNACL_PARAM_INVALID; + } + if (tile == 0) { + return NNACL_OK; + } + int size = matrix->col_; + float *src = matrix->data_; + int size_align = UP_ROUND(size, tile); + if (size_align <= 0) { + return 
NNACL_ERR; + } + matrix->packed_data_ = (float *)malloc(size_align * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + matrix->packed_row_ = matrix->row_; + matrix->packed_col_ = size_align; + memset(matrix->packed_data_, 0, size_align * sizeof(float)); + memcpy(matrix->packed_data_, src, size * sizeof(float)); + return NNACL_OK; +} + +static void RelativeShiftPad(const float *input_data, float *output_data, const int32_t *input_shape, int tid, + int thread_num) { + int row = input_shape[0]; + int col = input_shape[1]; + int out_area = row * (col + 1); + memset(output_data, 0, out_area * sizeof(float)); + for (int r = tid; r < row; r += thread_num) { + float *dst = output_data + r * (col + 1); + const float *src = input_data + r * col; + memcpy(dst, src, col * sizeof(float)); + } + int tile = row % thread_num; + for (int r = row - tile; r < row; r++) { + float *dst = output_data + r * (col + 1); + const float *src = input_data + r * col; + memcpy(dst, src, col * sizeof(float)); + } +} + +static void RelativeShiftSlice(const float *input_data, float *output_data, const int32_t *input_shape, int tid, + int thread_num) { + int row = input_shape[0]; + int col = input_shape[1]; + int begin = row; + memset(output_data, 0, row * row * sizeof(float)); + for (int r = tid; r < row; r += thread_num) { + float *dst = output_data + r * row; + const float *src = input_data + r * col + begin; + memcpy(dst, src, (col / 2) * sizeof(float)); + } + int tile = row % thread_num; + for (int r = row - tile; r < row; r++) { + float *dst = output_data + r * row; + const float *src = input_data + r * col + begin; + memcpy(dst, src, (col / 2) * sizeof(float)); + } +} + +static void RelativeShift(const Matrix *x, float *pad_buf, float *slice_buf) { + int x_area = x->row_ * x->col_; + int pad_area = x->row_ * (x->col_ + 1); + int slice_area = x->row_ * (x->col_ / 2); + int input_shape[] = {x->row_, x->col_}; + memset(slice_buf, 0, x->batch_ * x->row_ * (x->col_ / 2) * sizeof(float)); + for (int i = 0; i < x->batch_; i++) { + float *cur_x_data = x->data_ + i * x_area; + memset(pad_buf, 0, pad_area * sizeof(float)); + // pad: [row, col + 1] + RelativeShiftPad(cur_x_data, pad_buf, input_shape, 0, 1); + // reshape: [col + 1, row] + // slice last row: [col, row] + // reshape: [row, col] + // slice col -> [row, row + col / 2]: [row, col / 2] + float *cur_slice_data = slice_buf + i * slice_area; + RelativeShiftSlice(pad_buf, cur_slice_data, input_shape, 0, 1); + } +} + +static void ElementOptAddDiv(const float *input0, const float *input1, const float input2, float *output, + const int batch, const int area) { + const float mul = 1 / input2; + for (int b = 0; b < batch; b++) { + int index = 0; + const float *cur_input0 = input0 + b * area; + const float *cur_input1 = input1 + b * area; + float *cur_output = output + b * area; +#ifdef ENABLE_NEON + for (; index <= area - 4; index += C4NUM) { + float32x4_t vin0 = vld1q_f32(cur_input0 + index); + float32x4_t vin1 = vld1q_f32(cur_input1 + index); + float32x4_t vout = vaddq_f32(vin0, vin1); + vout = vmulq_n_f32(vout, mul); + vst1q_f32(cur_output + index, vout); + } +#endif + for (; index < area; index++) { + cur_output[index] = (cur_input0[index] + cur_input1[index]) * mul; + } + } +}
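+ +// GetTransposeParameter below fills TransposeParameter from explicit in/out shapes and a permutation: strides are accumulated from the innermost axis outward, matching the walk TransposeDimsFp32 performs.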
+static bool GetTransposeParameter(TransposeParameter *param, const int in_shape[], int in_shape_len, + const int out_shape[], int out_shape_len, const int perm[], int perm_len) { + param->num_axes_ = perm_len; + size_t shape_size = 1; + for (int i = 0; i < perm_len; i++) { + param->perm_[i] = perm[i]; + shape_size *= in_shape[i]; // check overflow + } + param->data_num_ = (int)shape_size; // check overflow + param->strides_[param->num_axes_ - 1] = 1; + param->out_strides_[param->num_axes_ - 1] = 1; + if (param->num_axes_ - 1 >= in_shape_len) { + return false; + } + if (param->num_axes_ - 1 >= out_shape_len) { + return false; + } + for (int i = param->num_axes_ - 2; i >= 0; i--) { + param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1]; + param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; + } + return true; +} + +void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, const Matrix *wq_mat, Matrix *bq_mat, + Matrix *q2wq_mat, Matrix *pu_mat, Matrix *pv_mat, Matrix *q2wq_with_pos_mat, + Matrix *q2wq_with_pu_trans_mat, Matrix *q2wq_with_pv_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // Q * WQ + int q_area = q_mat->packed_row_ * q_mat->packed_col_; + int wq_area = wq_mat->packed_row_ * wq_mat->packed_col_; + int q2wq_area = q2wq_mat->row_ * q2wq_mat->col_ * q2wq_mat->batch_ / param->batch_; + float *q2wq_data = q2wq_mat->data_; + memset(q2wq_data, 0, param->batch_ * q2wq_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_q = q_mat->packed_data_ + i * q_area; + float *cur_wq = wq_mat->packed_data_ + i * wq_area; + float *cur_q2wq = q2wq_data + i * q2wq_area; + MatMulOpt(cur_q, cur_wq, cur_q2wq, bq_mat->packed_data_, ActType_No, q_mat->col_, q_mat->row_, wq_mat->col_, + wq_mat->col_, OutType_Nhwc); + } + // transpose param init + TransposeParameter q_with_pos_trans_param; + int q_with_pos_trans_in_shape[] = {batch, param->q_seq_, num_heads, depth}; + int q_with_pos_trans_out_shape[] = {batch, num_heads, param->q_seq_, depth}; + int q_with_pos_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&q_with_pos_trans_param, q_with_pos_trans_in_shape, 4, q_with_pos_trans_out_shape, 4, + q_with_pos_perm, 4); + int q2wq_reshaped_area = q2wq_mat->row_ * q2wq_mat->col_; + // Q_WQ + POS_U + { + float *q_with_pu = q2wq_with_pos_mat->data_; + int q_with_pu_area = q2wq_with_pos_mat->row_ * q2wq_with_pos_mat->col_; + memset(q_with_pu, 0, q2wq_with_pos_mat->batch_ * q_with_pu_area * sizeof(float)); + for (int i = 0; i < q2wq_with_pos_mat->batch_; i++) { + float *cur_qw = q2wq_data + i * q2wq_reshaped_area; + float *cur_q_with_pu = q_with_pu + i * q_with_pu_area; + ElementAdd(cur_qw, pu_mat->packed_data_, cur_q_with_pu, q_with_pu_area); + } + // Q_WITH_U perm [0,2,1,3] + float *q_with_pu_trans = q2wq_with_pu_trans_mat->data_; + size_t q_with_pu_trans_data_size = (size_t)(q2wq_with_pu_trans_mat->batch_) * + (size_t)(q2wq_with_pu_trans_mat->row_) * (size_t)(q2wq_with_pu_trans_mat->col_) * + sizeof(float); + memset(q_with_pu_trans, 0, q_with_pu_trans_data_size); + TransposeDimsFp32(q_with_pu, q_with_pu_trans, q_with_pos_trans_out_shape, q_with_pos_trans_param.perm_, + q_with_pos_trans_param.strides_, q_with_pos_trans_param.out_strides_, + q_with_pos_trans_param.num_axes_, 0, 1); + } + + // Q_WQ + POS_V + { + float *q_with_pv = q2wq_with_pos_mat->data_; + int q_with_pv_area = q2wq_with_pos_mat->row_ * q2wq_with_pos_mat->col_; + memset(q_with_pv, 0, q2wq_with_pos_mat->batch_ * q_with_pv_area * sizeof(float)); + for (int i = 0; i < q2wq_with_pos_mat->batch_; i++) { + float *cur_qw = q2wq_data + i * q2wq_reshaped_area; + float *cur_q_with_pv = q_with_pv + i * q_with_pv_area; + ElementAdd(cur_qw,
pv_mat->packed_data_, cur_q_with_pv, q_with_pv_area); + } + // Q_WITH_V perm [0,2,1,3] + float *q_with_pv_trans = q2wq_with_pv_trans_mat->data_; + size_t q_with_pv_trans_data_size = (size_t)(q2wq_with_pv_trans_mat->batch_) * + (size_t)(q2wq_with_pv_trans_mat->row_) * (size_t)(q2wq_with_pv_trans_mat->col_) * + sizeof(float); + memset(q_with_pv_trans, 0, q_with_pv_trans_data_size); + TransposeDimsFp32(q_with_pv, q_with_pv_trans, q_with_pos_trans_out_shape, q_with_pos_trans_param.perm_, + q_with_pos_trans_param.strides_, q_with_pos_trans_param.out_strides_, + q_with_pos_trans_param.num_axes_, 0, 1); + } +} + +void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, const Matrix *wk_mat, Matrix *bk_mat, + Matrix *k2wk_mat, Matrix *k2wk_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // K * WK + int k_area = k_mat->packed_row_ * k_mat->packed_col_; + int wk_area = wk_mat->packed_row_ * wk_mat->packed_col_; + int k2wk_area = k2wk_mat->row_ * k2wk_mat->col_ * k2wk_mat->batch_ / param->batch_; + float *k2wk = k2wk_mat->data_; + memset(k2wk, 0, param->batch_ * k2wk_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_k = k_mat->packed_data_ + i * k_area; + float *cur_wk = wk_mat->packed_data_ + i * wk_area; + float *cur_k2wk = k2wk + i * k2wk_area; + MatMulOpt(cur_k, cur_wk, cur_k2wk, bk_mat->packed_data_, ActType_No, k_mat->col_, k_mat->row_, wk_mat->col_, + wk_mat->col_, OutType_Nhwc); + } + // K * WK perm [0,2,3,1] + float *k2wk_trans_data = k2wk_trans_mat->data_; + int k2wk_trans_area = k2wk_trans_mat->row_ * k2wk_trans_mat->col_; + memset(k2wk_trans_data, 0, k2wk_trans_mat->batch_ * k2wk_trans_area * sizeof(float)); + TransposeParameter k2wk_trans_param; + int k2wk_in_shape[] = {batch, param->k_seq_, num_heads, depth}; + int k2wk_out_shape[] = {batch, num_heads, depth, param->k_seq_}; + int k2wk_perm[] = {0, 2, 3, 1}; + (void)GetTransposeParameter(&k2wk_trans_param, k2wk_in_shape, 4, k2wk_out_shape, 4, k2wk_perm, 4); + TransposeDimsFp32(k2wk, k2wk_trans_data, k2wk_out_shape, k2wk_trans_param.perm_, k2wk_trans_param.strides_, + k2wk_trans_param.out_strides_, k2wk_trans_param.num_axes_, 0, 1); +} + +void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, const Matrix *wv_mat, Matrix *bv_mat, + Matrix *v2wv_mat, Matrix *v2wv_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // V * WV + int v_area = v_mat->packed_row_ * v_mat->packed_col_; + int wv_area = wv_mat->packed_row_ * wv_mat->packed_col_; + int v2wv_area = v2wv_mat->row_ * v2wv_mat->col_ * v2wv_mat->batch_ / param->batch_; + float *v2wv = v2wv_mat->data_; + memset(v2wv, 0, param->batch_ * v2wv_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_v = v_mat->packed_data_ + i * v_area; + float *cur_wv = wv_mat->packed_data_ + i * wv_area; + float *cur_v2wv = v2wv + i * v2wv_area; + MatMulOpt(cur_v, cur_wv, cur_v2wv, bv_mat->packed_data_, ActType_No, v_mat->col_, v_mat->row_, wv_mat->col_, + wv_mat->col_, OutType_Nhwc); + } + // V * WV perm [0,2,1,3] + float *v2wv_trans_data = v2wv_trans_mat->data_; + int v2wv_trans_area = v2wv_trans_mat->row_ * v2wv_trans_mat->col_; + memset(v2wv_trans_data, 0, v2wv_trans_mat->batch_ * v2wv_trans_area * sizeof(float)); + TransposeParameter v2wv_trans_param; + int v2wv_in_shape[] = {batch, param->v_seq_, num_heads, depth}; 
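+ // perm {0, 2, 1, 3} reorders (batch, v_seq, num_heads, depth) into (batch, num_heads, v_seq, depth), so each head's value rows are contiguous for the logits * V matmul in RelPosAttention.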
+ int v2wv_out_shape[] = {batch, num_heads, param->v_seq_, depth}; + int v2wv_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&v2wv_trans_param, v2wv_in_shape, 4, v2wv_out_shape, 4, v2wv_perm, 4); + TransposeDimsFp32(v2wv, v2wv_trans_data, v2wv_out_shape, v2wv_trans_param.perm_, v2wv_trans_param.strides_, + v2wv_trans_param.out_strides_, v2wv_trans_param.num_axes_, 0, 1); +} + +void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, const Matrix *wp_mat, Matrix *p2wp_mat, + Matrix *p2wp_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + + // P * WP + int p_area = p_mat->packed_row_ * p_mat->packed_col_; + int wp_area = wp_mat->packed_row_ * wp_mat->packed_col_; + int p2wp_area = p2wp_mat->row_ * p2wp_mat->col_ * p2wp_mat->batch_ / param->batch_; + float *p2wp_data = p2wp_mat->data_; + memset(p2wp_data, 0, param->batch_ * p2wp_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_p = p_mat->packed_data_ + i * p_area; + float *cur_wp = wp_mat->packed_data_ + i * wp_area; + float *cur_p2wp = p2wp_data + i * p2wp_area; + MatMulOpt(cur_p, cur_wp, cur_p2wp, NULL, ActType_No, p_mat->col_, p_mat->row_, wp_mat->col_, wp_mat->col_, + OutType_Nhwc); + } + // P * WP perm [0,2,3,1] + float *p2wp_trans_data = p2wp_trans_mat->data_; + int p2wp_trans_area = p2wp_trans_mat->row_ * p2wp_trans_mat->col_; + memset(p2wp_trans_data, 0, p2wp_trans_mat->batch_ * p2wp_trans_area * sizeof(float)); + TransposeParameter p2wp_trans_param; + int p2wp_in_shape[] = {batch, param->p_seq_, num_heads, depth}; + int p2wp_out_shape[] = {batch, num_heads, depth, param->p_seq_}; + int p2wp_perm[] = {0, 2, 3, 1}; + (void)GetTransposeParameter(&p2wp_trans_param, p2wp_in_shape, 4, p2wp_out_shape, 4, p2wp_perm, 4); + TransposeDimsFp32(p2wp_data, p2wp_trans_data, p2wp_out_shape, p2wp_trans_param.perm_, p2wp_trans_param.strides_, + p2wp_trans_param.out_strides_, p2wp_trans_param.num_axes_, 0, 1); +} + +void CalculateLogits(RelativePositionAttentionParameter *param, Matrix *q2wq_with_pu_trans_mat, + Matrix *q2wq_with_pv_trans_mat, Matrix *k2wk_trans_mat, Matrix *p2wp_trans_mat, + Matrix *logits_with_u_mat, Matrix *logits_with_v_mat, Matrix *logits_with_v_pad_mat, + Matrix *logits_with_v_shifted_mat, Matrix *logits_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int depth = d_model / num_heads; + + // pack Q_WITH_U as left_matrix + // since we malloc dst data, pack function can not be failed + (void)PackLeftMatrix(q2wq_with_pu_trans_mat, param->row_tile_); + // pack Q_WITH_V as left_matrix + (void)PackLeftMatrix(q2wq_with_pv_trans_mat, param->row_tile_); + // pack K * WK as right_matrix + (void)PackRightMatrix(k2wk_trans_mat, param->col_tile_); + // pack P * WP as right_matrix + (void)PackRightMatrix(p2wp_trans_mat, param->col_tile_); + + // q_with_pu * k = logits_with_u + MatMulOpt(q2wq_with_pu_trans_mat->packed_data_, k2wk_trans_mat->packed_data_, logits_with_u_mat->data_, NULL, + ActType_No, q2wq_with_pu_trans_mat->col_, logits_with_u_mat->row_, logits_with_u_mat->col_, + logits_with_u_mat->col_, OutType_Nhwc); + + // q_with_pv * p = logits_with_v + MatMulOpt(q2wq_with_pv_trans_mat->packed_data_, p2wp_trans_mat->packed_data_, logits_with_v_mat->data_, NULL, + ActType_No, q2wq_with_pv_trans_mat->col_, logits_with_v_mat->row_, logits_with_v_mat->col_, + logits_with_v_mat->col_, OutType_Nhwc); + // relative shift logits_with_v + float *pad_buf = 
logits_with_v_pad_mat->data_; + float *logits_with_v_shifted_data = logits_with_v_shifted_mat->data_; + RelativeShift(logits_with_v_mat, pad_buf, logits_with_v_shifted_data); + // logits = (logits_with_u + logits_with_v) / sqrt(depth) + float *logits_buffer = logits_mat->data_; + ElementOptAddDiv(logits_with_u_mat->data_, logits_with_v_shifted_data, sqrtf((float)depth), logits_buffer, + logits_with_u_mat->batch_, logits_with_u_mat->row_ * logits_with_u_mat->col_); +} + +void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_mat, Matrix *softmax_mat, + Matrix *v2wv_trans_mat, Matrix *logits2v_mat, Matrix *logits2v_trans_mat, const Matrix *wo_mat, + Matrix *bo_mat, Matrix *output_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + float *logits_buffer = logits_mat->data_; + // softmax(logits) + SoftmaxLastAxis(logits_buffer, softmax_mat->data_, batch * num_heads * softmax_mat->row_, softmax_mat->col_); + + // logits * v + (void)PackLeftMatrix(softmax_mat, param->row_tile_); + (void)PackRightMatrix(v2wv_trans_mat, param->col_tile_); + int softmax_logits_area = softmax_mat->packed_row_ * softmax_mat->packed_col_; + int v2wv_area = v2wv_trans_mat->packed_row_ * v2wv_trans_mat->packed_col_; + int logits2v_area = logits2v_mat->row_ * logits2v_mat->col_; + float *logits2v_data = logits2v_mat->data_; + memset(logits2v_data, 0, logits2v_mat->batch_ * logits2v_area * sizeof(float)); + for (int i = 0; i < logits2v_mat->batch_; i++) { + float *cur_logits = softmax_mat->packed_data_ + i * softmax_logits_area; + float *cur_v2wv = v2wv_trans_mat->packed_data_ + i * v2wv_area; + float *cur_logits2v = logits2v_data + i * logits2v_area; + MatMulOpt(cur_logits, cur_v2wv, cur_logits2v, NULL, ActType_No, softmax_mat->col_, softmax_mat->row_, + v2wv_trans_mat->col_, v2wv_trans_mat->col_, OutType_Nhwc); + } + // multi_head output perm [0,2,1,3] + float *logits2v_trans_data = logits2v_trans_mat->data_; + int logits2v_trans_area = logits2v_trans_mat->row_ * logits2v_trans_mat->col_; + memset(logits2v_trans_data, 0, logits2v_trans_mat->batch_ * logits2v_trans_area * sizeof(float)); + TransposeParameter logits2v_trans_param; + int logits2v_trans_in_shape[] = {batch, num_heads, param->q_seq_, depth}; + int logits2v_trans_out_shape[] = {batch, param->q_seq_, num_heads, depth}; + int logits2v_trans_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&logits2v_trans_param, logits2v_trans_in_shape, 4, logits2v_trans_out_shape, 4, + logits2v_trans_perm, 4); + TransposeDimsFp32(logits2v_data, logits2v_trans_data, logits2v_trans_out_shape, logits2v_trans_param.perm_, + logits2v_trans_param.strides_, logits2v_trans_param.out_strides_, logits2v_trans_param.num_axes_, 0, + 1); + // concat = reshape [batch, -1, d_model] + logits2v_trans_mat->batch_ = batch; + logits2v_trans_mat->row_ = param->q_seq_; + logits2v_trans_mat->col_ = param->d_model_; + // * o + (void)PackLeftMatrix(logits2v_trans_mat, param->row_tile_); + int concat_out_area = logits2v_trans_mat->packed_row_ * logits2v_trans_mat->packed_col_; + int wo_area = wo_mat->packed_row_ * wo_mat->packed_col_; + int output_area = output_mat->row_ * output_mat->col_; + for (int i = 0; i < output_mat->batch_; i++) { + float *cur_concat_out = logits2v_trans_mat->packed_data_ + i * concat_out_area; + float *cur_wo = wo_mat->packed_data_ + i * wo_area; + float *cur_output = output_mat->data_ + i * output_area; + MatMulOpt(cur_concat_out, cur_wo, cur_output,
bo_mat->packed_data_, ActType_No, logits2v_trans_mat->col_, + logits2v_trans_mat->row_, wo_mat->col_, wo_mat->col_, OutType_Nhwc); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..68f8b05fbaaa966c88c0665e2969ba40eea09f3d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h @@ -0,0 +1,72 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ +#define MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ + +#include "nnacl_c/attention_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct Matrix { + float *data_; + int row_; + int col_; + float *packed_data_; + int packed_row_; + int packed_col_; + int batch_; + bool is_transpose_; +} Matrix; + +int InitMatrix(Matrix *matrix, int batch, int row, int col, bool is_trans); + +size_t LeftMatrixPackElementSize(Matrix *matrix, int row_tile); + +size_t RightMatrixPackElementSize(Matrix *matrix, int col_tile); + +int PackLeftMatrix(Matrix *matrix, int row_tile); + +int PackRightMatrix(Matrix *matrix, int col_tile); + +int PackAttentionBias(Matrix *matrix, int tile); + +void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, const Matrix *wq_mat, Matrix *bq_mat, + Matrix *q2wq_mat, Matrix *pu_mat, Matrix *pv_mat, Matrix *q2wq_with_pos_mat, + Matrix *q2wq_with_pu_trans_mat, Matrix *q2wq_with_pv_trans_mat); + +void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, const Matrix *wk_mat, Matrix *bk_mat, + Matrix *k2wk_mat, Matrix *k2wk_trans_mat); + +void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, const Matrix *wv_mat, Matrix *bv_mat, + Matrix *v2wv_mat, Matrix *v2wv_trans_mat); + +void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, const Matrix *wp_mat, Matrix *p2wp_mat, + Matrix *p2wp_trans_mat); + +void CalculateLogits(RelativePositionAttentionParameter *param, Matrix *q2wq_with_pu_trans_mat, + Matrix *q2wq_with_pv_trans_mat, Matrix *k2wk_trans_mat, Matrix *p2wp_trans_mat, + Matrix *logits_with_u_mat, Matrix *logits_with_v_mat, Matrix *logits_with_v_pad_mat, + Matrix *logits_with_v_shifted_mat, Matrix *logits_mat); + +void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_mat, Matrix *softmax_mat, + Matrix *v2wv_trans_mat, Matrix *logits2v_mat, Matrix *logits2v_trans_mat, const Matrix *wo_mat, + Matrix *bo_mat, Matrix *output_mat); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d6a0d53f8e2f20e688b16b66defb0753fd0af9cd --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/batchnorm_fp32_simd.h" +#include "nnacl_c/kernel/fused_batch_norm.h" +#include "nnacl_c/tensor_c_utils.h" + +int FusedBatchNormEval(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + if (fused_batch_norm->trained_) { + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT]; + TensorC *var_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT]; + (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); + (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, NNACLGetSize(mean_tensor)); + (void)memcpy(fused_batch_norm->bn_.variance_, var_tensor->data_, NNACLGetSize(var_tensor)); + } + return NNACL_OK; +} + +void BatchNormSetupVirtualBatch(KernelBase *self, int virtual_batch_multiplier, int momentum) { + BatchNormStruct *bn = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_VOID(bn); + if (virtual_batch_multiplier > 0) { + float new_momentum = (momentum < 0.0f) ?
(bn->momentum_ / virtual_batch_multiplier) : momentum; + bn->momentum_ = new_momentum; + } + return; +} + +void BatchNormFp32(const float *input, const float *mean, const float *variance, const BatchNormStruct *param, + int task_id, int thread_num, float *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int channel = param->channel_; + int cur_offset = completed_units * channel; + float epsilon = param->epsilon_; + + for (int i = 0; i < cur_unit; i++) { + const float *unit_input = input + cur_offset; + float *unit_output = output + cur_offset; + int c = 0; + + SIMD_RUN_NO_SCALAR(BatchNormFp32, c, unit_input, mean, variance, channel, epsilon, unit_output); + + for (; c < channel; c++) { + float variance_sqrt = sqrtf(variance[c] + epsilon); + unit_output[c] = (unit_input[c] - mean[c]) / variance_sqrt; + } + cur_offset += channel; + } +} + +void FusedBatchNormFp32(const float *input, const float *scale, const float *offset, const float *mean, + const float *variance, const BatchNormStruct *param, int task_id, int thread_num, + float *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int channel = param->channel_; + float epsilon = param->epsilon_; + int cur_offset = completed_units * channel; + + for (int i = 0; i < cur_unit; i++) { + const float *unit_input = input + cur_offset; + float *unit_output = output + cur_offset; + int c = 0; + + SIMD_RUN_NO_SCALAR(FusedBatchNormFp32, c, unit_input, scale, offset, mean, variance, channel, epsilon, unit_output); + + for (; c < channel; c++) { + float variance_sqrt = sqrtf(variance[c] + epsilon); + float norm_val = (unit_input[c] - mean[c]) / variance_sqrt; + unit_output[c] = norm_val * scale[c] + offset[c]; + } + cur_offset += channel; + } +} + +void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormStruct *param, + float *save_mean, float *save_var, bool isBatchNorm2d) { + const float N = (float)param->unit_; + const float VN = N; + const float VNUB = (isBatchNorm2d == false) ? N : ((N > 1.0f) ? 
(N - 1.0f) : 1.0f); + const float momentum = (1.0f - param->momentum_); + + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_mean[c] += input[idx]; + } + } + for (int c = 0; c < param->channel_; c++) { + run_mean[c] /= N; + } + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_var[c] += (input[idx] - run_mean[c]) * (input[idx] - run_mean[c]); + } + } + for (int c = 0; c < param->channel_; c++) { + float unbiased_var = (run_var[c] / VNUB); + run_var[c] = (run_var[c] / VN); + save_mean[c] = momentum * save_mean[c] + (1.0f - momentum) * run_mean[c]; + save_var[c] = momentum * save_var[c] + (1.0f - momentum) * unbiased_var; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..f09f445343f035de3500b52da1562ac9701b5770 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_BATCHNORM_FP32_H_ +#define NNACL_FP32_BATCHNORM_FP32_H_ + +#include "nnacl_c/kernel/batch_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormSetupVirtualBatch(KernelBase *self, int virtual_batch_multiplier, int momentum); +void BatchNormFp32(const float *input, const float *mean, const float *variance, const BatchNormStruct *param, + int task_id, int thread_num, float *output); + +int FusedBatchNormEval(KernelBase *self); +void FusedBatchNormFp32(const float *input, const float *scale, const float *offset, const float *mean, + const float *variance, const BatchNormStruct *param, int task_id, int thread_num, + float *output); +void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormStruct *param, + float *save_mean, float *save_var, bool isBatchNorm2d); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_BATCHNORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..fbf5f039ca76a532a05c8d2474b45d4d7a9de601 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_BATCHNORM_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_BATCHNORM_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BatchNormFp32@SIMD_INSTRUCTION@(int index, const float *input, const float *mean, + const float *variance, int channel, float epsilon, float *output) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_data = SIMD_LD_F32(input + index); + SIMD_F32 mean_ = SIMD_LD_F32(mean + index); + SIMD_F32 variance_ = SIMD_LD_F32(variance + index); + SIMD_F32 variance_sqrt = SIMD_SQRT_F32(SIMD_ADD_F32(variance_, SIMD_MOV_F32(epsilon))); + SIMD_F32 output_data = SIMD_DIV_F32(SIMD_SUB_F32(input_data, mean_), variance_sqrt); + SIMD_ST_F32(output + index, output_data); + } + return index; +} + +static inline int FusedBatchNormFp32@SIMD_INSTRUCTION@(int index, const float *input, const float *scale, + const float *offset, const float *mean, const float *variance, int channel, float epsilon, float *output) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_data = SIMD_LD_F32(input + index); + SIMD_F32 scale_ = SIMD_LD_F32(scale + index); + SIMD_F32 offset_ = SIMD_LD_F32(offset + index); + SIMD_F32 mean_ = SIMD_LD_F32(mean + index); + SIMD_F32 variance_ = SIMD_LD_F32(variance + index); + SIMD_F32 variance_sqrt = SIMD_SQRT_F32(SIMD_ADD_F32(variance_, SIMD_MOV_F32(epsilon))); + SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input_data, mean_), variance_sqrt); + SIMD_F32 output_data = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_), offset_); + SIMD_ST_F32(output + index, output_data); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif
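To make the arithmetic of the two kernels above concrete: per channel c, the fused form computes (x[c] - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + offset[c]. A self-contained scalar sketch of that formula (plain C; it deliberately avoids the BatchNormStruct plumbing from this patch):

#include <math.h>
#include <stdio.h>

/* Scalar reference for one NHWC unit, mirroring FusedBatchNormFp32's per-channel math. */
static void fused_bn_ref(const float *x, const float *scale, const float *offset,
                         const float *mean, const float *var, float eps, int channel, float *y) {
  for (int c = 0; c < channel; ++c) {
    y[c] = (x[c] - mean[c]) / sqrtf(var[c] + eps) * scale[c] + offset[c];
  }
}

int main(void) {
  float x[2] = {1.0f, 2.0f}, scale[2] = {1.0f, 0.5f}, offset[2] = {0.0f, 1.0f};
  float mean[2] = {0.5f, 1.0f}, var[2] = {0.25f, 1.0f}, y[2];
  fused_bn_ref(x, scale, offset, mean, var, 1e-5f, 2, y);
  printf("%f %f\n", y[0], y[1]); /* approximately 1.0 and 1.5 */
  return 0;
}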
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..206d833a06e39da7a00a7e44a203850b9aeb9e1e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_H_ +#define MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BCEWithLogitLoss(const float *logits, const float *label, const float *weight, const float *pos_weight, int length, + bool reduction, float *output, float *reduction_sum); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..12e3fcdf33d695271879612f2ae5ec51b86f35ba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in @@ -0,0 +1,62 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BCEWithLogitLoss@SIMD_INSTRUCTION@(int index, const float *logits, const float *label, + const float *weight, const float *pos_weight, int length, bool reduction, float *output, + float *reduction_sum) { + SIMD_F32 zero = SIMD_SET0_F32; + SIMD_F32 ones = SIMD_MOV_F32(1.0f); + SIMD_F32 middle_output = SIMD_SET0_F32; + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 logits_tmp = SIMD_LD_F32(logits + index); + SIMD_F32 label_tmp = SIMD_LD_F32(label + index); + SIMD_F32 weight_tmp = SIMD_LD_F32(weight + index); + SIMD_F32 pos_weight_tmp = SIMD_LD_F32(pos_weight + index); + SIMD_F32 neg_logits_tmp = SIMD_SUB_F32(zero, logits_tmp); + SIMD_F32 max_value = neg_logits_tmp; + max_value = SIMD_MAX_F32(max_value, zero); + SIMD_F32 neg_max_value = SIMD_SUB_F32(zero, max_value); + SIMD_F32 log_weight = SIMD_ADD_F32(SIMD_MUL_F32(SIMD_SUB_F32(pos_weight_tmp, ones), label_tmp), ones); + SIMD_F32 log_exp_value = + SIMD_LOG_F32(SIMD_ADD_F32(SIMD_HEXP_F32(neg_max_value), SIMD_HEXP_F32(SIMD_SUB_F32(neg_logits_tmp, max_value)))); + SIMD_F32 loss = SIMD_ADD_F32(SIMD_MUL_F32(SIMD_SUB_F32(ones, label_tmp), logits_tmp), + SIMD_MUL_F32(log_weight, SIMD_ADD_F32(log_exp_value, max_value))); + if (reduction) { + middle_output = SIMD_FMADD_F32(loss, weight_tmp, middle_output); + } else { + SIMD_ST_F32(output + index, SIMD_MUL_F32(loss, weight_tmp)); + } + } + if (reduction) { + *reduction_sum += SIMD_GET_SUM_F32(middle_output); + } + return index; +} +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_loigts_loss_fp32.c
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_loigts_loss_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5209ca413003de78cc85c1e89ec0ff4d05a158ef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_loigts_loss_fp32.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/bce_with_logits_loss_fp32.h" +#include "nnacl_c/bce_with_logits_loss_fp32_simd.h" + +void BCEWithLogitLoss(const float *logits, const float *label, const float *weight, const float *pos_weight, int length, + bool reduction, float *output, float *reduction_sum) { + int i = 0; + float simd_reduction_output = 0.0f; + SIMD_RUN_NO_SCALAR(BCEWithLogitLoss, i, logits, label, weight, pos_weight, length, reduction, output, + &simd_reduction_output); + for (; i < length; ++i) { + float logits_value = logits[i]; + float label_value = label[i]; + float weight_value = weight[i]; + float post_weight_value = pos_weight[i]; + float max_value = -logits_value; + max_value = max_value > 0.f ? max_value : 0.f; + float log_weight = (post_weight_value - 1.0f) * label_value + 1.0f; + float log_exp_value = logf(expf(-max_value) + expf(-logits_value - max_value)); + float loss = (1.0f - label_value) * logits_value + log_weight * (log_exp_value + max_value); + if (reduction) { + simd_reduction_output += loss * weight_value; + } else { + output[i] = loss * weight_value; + } + } + if (reduction) { + *reduction_sum = simd_reduction_output; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c new file mode 100644 index 0000000000000000000000000000000000000000..076672a35eccfccc62d45652b4bf633814096f3d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c @@ -0,0 +1,123 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/bias_add.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/bias_add_simd.h" + +void BiasAddByInnerCore(const float *input, const float *bias, float *output, int64_t num) { + int64_t index = 0; + + SIMD_RUN_NO_SCALAR(BiasAddByInnerCore, index, input, bias, output, num); + + for (; index < num; ++index) { + output[index] = input[index] + bias[index]; + } +} + +void BiasAddByBatchCore(const float *input, const float *bias, float *output, int64_t num) { + float *output1 = output; + float *output2 = output + num; + float *output3 = output + num * 2; + float *output4 = output + num * 3; + int64_t index = 0; + + SIMD_RUN_NO_SCALAR(BiasAddByBatchCore, index, input, bias, output1, output2, output3, output4, num); + + const float *input_data1 = input; + const float *input_data2 = input + num; + const float *input_data3 = input + num * 2; + const float *input_data4 = input + num * 3; + for (; index < num; ++index) { + output1[index] = input_data1[index] + bias[index]; + output2[index] = input_data2[index] + bias[index]; + output3[index] = input_data3[index] + bias[index]; + output4[index] = input_data4[index] + bias[index]; + } +} + +void DoBiasAddByBatch(const float *input, const float *bias, float *output, int64_t start_inner, int64_t start_outer, + int64_t end_inner, int64_t end_outer, int64_t inner_num) { + const float *cur_bias = bias + start_inner; + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner - start_inner); + return; + } + if (start_inner != 0) { + BiasAddByInnerCore(input, cur_bias, output, inner_num - start_inner); + start_outer += 1; + input += inner_num - start_inner; + cur_bias = bias; + output += inner_num - start_inner; + } + int64_t step = C4NUM * inner_num; + for (; start_outer <= end_outer - C4NUM; start_outer += C4NUM) { + BiasAddByBatchCore(input, cur_bias, output, inner_num); + input += step; + output += step; + } + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(input, cur_bias, output, inner_num); + input += inner_num; + output += inner_num; + } + BiasAddByInnerCore(input, cur_bias, output, end_inner); +} + +void DoBiasAddByInner(const float *input, const float *bias, float *output, int64_t start_inner, int64_t start_outer, + int64_t end_inner, int64_t end_outer, int64_t inner_num) { + const float *cur_bias = bias + start_inner; + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner - start_inner); + return; + } else { + BiasAddByInnerCore(input, cur_bias, output, inner_num - start_inner); + start_outer += 1; + input += inner_num - start_inner; + cur_bias = bias; + output += inner_num - start_inner; + } + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner); + return; + } else { + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(input, cur_bias, output, inner_num); + input += inner_num; + output += inner_num; + } + } + BiasAddByInnerCore(input, bias, output, end_inner); +} + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority) { + if (inner_num == 0) { + return; + } + int64_t start_outer = start / inner_num; + int64_t start_inner = start % inner_num; + int64_t end_outer = end / inner_num; + int64_t end_inner = end % inner_num; + const float *cur_input = input + start; + float *cur_output = output + start; + + if (batch_priority) { + DoBiasAddByBatch(cur_input, bias, cur_output, start_inner, start_outer, 
end_inner, end_outer, inner_num); + } else { + DoBiasAddByInner(cur_input, bias, cur_output, start_inner, start_outer, end_inner, end_outer, inner_num); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h new file mode 100644 index 0000000000000000000000000000000000000000..210b176858dabd3292fbbd64102c851d054e5f69 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_H_ + +#include <stdint.h> +#include <stdbool.h> + +#ifdef __cplusplus +extern "C" { +#endif + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority); + +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..baa787c4029e1a4e504bc41db3a05cf82d3344d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in @@ -0,0 +1,57 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
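+ */
+
+/*
+ * Like bce_with_logits_loss_fp32_simd.h.in above, this file is a template rather than a
+ * regular header: the build instantiates it once per enabled instruction set, substituting
+ * @SIMD_INSTRUCTION@ / @SIMD_INSTRUCTION_LOWER@ with the ISA name (e.g. AVX512, AVX, SSE,
+ * NEON) so that SIMD_F32, SIMD_LD_F32, BLOCK_NUM, and friends resolve to that ISA's
+ * intrinsics and lane count. Each generated function consumes as many full BLOCK_NUM-wide
+ * blocks as fit in num and returns the index at which the caller's scalar tail loop resumes.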
+ */ + +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BiasAddByInnerCore@SIMD_INSTRUCTION@(int index, const float *input, const float *bias, float *output, + int64_t num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(input + index); + SIMD_F32 vin1 = SIMD_LD_F32(bias + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1); + SIMD_ST_F32(output + index, vout); + } + return index; +} + +static inline int BiasAddByBatchCore@SIMD_INSTRUCTION@(int index, const float *input, const float *bias, float *output1, + float *output2, float *output3, float *output4, int64_t num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_LDX4_F32(input_data, input + index, num); + SIMD_F32 bias_data = SIMD_LD_F32(bias + index); + SIMD_ST_F32(output1 + index, SIMD_ADD_F32(input_data1, bias_data)); + SIMD_ST_F32(output2 + index, SIMD_ADD_F32(input_data2, bias_data)); + SIMD_ST_F32(output3 + index, SIMD_ADD_F32(input_data3, bias_data)); + SIMD_ST_F32(output4 + index, SIMD_ADD_F32(input_data4, bias_data)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_SIMD_H_ \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..070f439ae31e53c63beaef0a6d3275d437b3d262 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c @@ -0,0 +1,77 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
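+ */
+
+/*
+ * These helpers compute the distance between two length-m vectors a and b under an L_p
+ * norm, dst = (sum_i |a_i - b_i|^p)^(1/p), with dedicated specializations:
+ *   p = 0:   counts coordinates that differ (MSMIN(ceilf(|diff|), 1.0f) is 0 or 1),
+ *   p = 1:   sums absolute differences,
+ *   p = 2:   accumulates squares and takes a single sqrtf, avoiding powf,
+ *   p = inf: takes the maximum absolute difference.
+ * The general-p path pays one powf per element; only it and the two-norm path are given
+ * SIMD bodies (see cdist_fp32_simd.h.in) via SIMD_RUN_NO_SCALAR.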
+ */ +#include "nnacl_c/fp32/cdist_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/cdist_fp32_simd.h" + +void CdistTwoNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p) { + float result = 0; + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(CdistTwoNormalOpt, i, a, b, &result, m); + + for (; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += x * x; + } + result = sqrtf(result); + *dst = result; + + return; +} + +void CdistPNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p) { + float result = 0; + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(CdistPNormalOpt, i, a, b, &result, m, p); + + for (; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += powf(x, p); + } + result = powf(result, 1.0 / p); + *dst = result; + + return; +} + +void CdistZeroNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += MSMIN(ceilf(x), 1.0f); + } + *c = result; +} + +void CdistOneNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += x; + } + *c = result; +} + +void CdistInfNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result = MSMAX(result, x); + } + *c = result; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..e7f408ac910af16aacc1b9d49ae777200b911460 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_CDIST_H_ +#define MINDSPORE_NNACL_FP32_CDIST_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CdistTwoNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p); +void CdistPNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p); + +void CdistZeroNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); +void CdistOneNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); +void CdistInfNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CDIST_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..7e88ea1d0af55159a71c96e00b61f8d5f1f6dc11 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CDIST_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_CDIST_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t CdistTwoNormalOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + SIMD_F32 tmp_vec = SIMD_SUB_F32(a_vec, b_vec); + tmp_vec = SIMD_ABS_F32(tmp_vec); + result_vec = SIMD_FMADD_F32(tmp_vec, tmp_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +static inline int64_t CdistPNormalOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size, float p) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + SIMD_F32 p_vec = SIMD_MOV_F32(p); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + SIMD_F32 tmp_vec = SIMD_SUB_F32(a_vec, b_vec); + tmp_vec = SIMD_ABS_F32(tmp_vec); + tmp_vec = SIMD_POW_F32(tmp_vec, p_vec); + result_vec = SIMD_ADD_F32(tmp_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c new file mode 100644 index 
0000000000000000000000000000000000000000..ae9375ae3bf1ffee0e76d15c9c813b068d97ca0f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/common_func_fp32.h" + +void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) { + if (size == 0) { + return; + } + for (size_t oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size; + int oc_mod = oc % size; + for (int hw = 0; hw < (int)plane_size; hw++) { + int src_index = oc_div * size * plane_stride + hw * size + oc_mod; + int dst_index = hw * oc_stride + oc; + float value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } +} + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, size_t relu_type) { +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) + PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); +#else + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = stride * sizeof(float); + PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#endif +} + +void WinogradPostConvFuncFp32CX(const float *cx_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t relu_type) { +#ifdef ENABLE_AVX + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = (plane_stride - plane_size) * C8NUM * sizeof(float); + WinogradPostFuncBiasReluC8(out_ptr, cx_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + size_t oc4mod = output_channel % C4NUM; + size_t oc4div = output_channel - oc4mod; + size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); + WinogradPostFuncBiasReluC4(out_ptr, cx_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type); +#else + PostConvFuncComm(cx_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, + C4NUM); +#endif +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + const int unitStep = 4 * length; + for (int y = 0; y < h; ++y) { + float *dstY = M + y * w * unitStep; + for (int x = 0; x < w; ++x) { + float *dstX = dstY + x * unitStep; + const 
float *srcX = S + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float)); + for (int i = 0; i < k; ++i) { + float b = B[i * h + y]; + const float *srcY = srcX + i * w * unitStep; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcY[j] * b; + } + } + } + } +} + +// M = S * B , M = h * w * l, S = h * k * l, B = k * w +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + const int unitStep = 4 * length; + for (int y = 0; y < h; ++y) { + float *dstY = M + y * w * unitStep; + const float *srcY = S + y * k * unitStep; + + for (int x = 0; x < w; ++x) { + float *dstX = dstY + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float)); + for (int i = 0; i < k; ++i) { + const float *srcX = srcY + i * unitStep; + float b = B[i * h + x]; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcX[j] * b; + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..400a5f554dc7bf0820e036d720b36b45d6c29c04 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
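+ */
+
+/*
+ * Shared post-processing for the fp32 convolution kernels: take GEMM/Winograd output that
+ * is tiled in channel blocks (C4/C8), repack it to the plain output layout, add the bias,
+ * and apply the fused ReLU/ReLU6 activation in the same pass. The scalar fallback
+ * PostConvFuncComm in common_func_fp32.c spells out the index mapping,
+ *
+ *   src_index = (oc / tile) * tile * plane_stride + hw * tile + (oc % tile)
+ *   dst_index = hw * oc_stride + oc
+ *
+ * while ENABLE_ARM / ENABLE_SSE / ENABLE_AVX builds route to the hand-written assembly
+ * declared below (PostFuncBiasReluC8, WinogradPostFuncBiasReluC4/C8).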
+ */ + +#ifndef MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ +#define MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ + +#include <stddef.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/conv_parameter.h" + +typedef struct ConvDwFp32BorderParam { + float *dst; + const float *src; + const float *weight; + const float *bias; + size_t height; + size_t width; + size_t in_kh_step; + size_t in_kw_step; + size_t kernel_w; + size_t relu; + size_t relu6; +} ConvDwFp32BorderParam; + +#ifdef __cplusplus +extern "C" { +#endif + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, size_t relu_type); + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); + +void WinogradPostConvFuncFp32CX(const float *cx_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t relu_type); + +void WinogradPostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +void WinogradPostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +#ifdef ENABLE_AVX +void ConvDwFp32Border(ConvDwFp32BorderParam *param); +#else +void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); +#endif +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, + size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, + size_t in_kh_step, size_t in_kw_step); +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step); +#endif + +#ifdef ENABLE_ARM64 +void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w); + +void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, + size_t relu6); + +void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, + int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); + +void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, + int row_size, int channel, int output_h, int output_w, size_t relu, size_t
relu6); + +void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); + +void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); + +void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); +#endif + +#ifdef __cplusplus +} +#endif +#endif /* MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..a58557dfa06f6faf392cb04bf5a12625b731c104 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ +#include <stdint.h> +#include <stdbool.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +inline int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +inline int ConstantOfShapeFp32(float *output, int start, int end, float value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0c1cdd2c080302d2fb6c4853fe50f749aa8e42a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c @@ -0,0 +1,1608 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
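+ */
+
+/*
+ * Each Conv1x1SW<rows>x<cols>AVXKernel below computes one <rows> x <cols> register tile of
+ * a 1x1 sliding-window convolution in inline AVX assembly; the 3x32 variant, for example,
+ * keeps 3 rows x 4 vectors = ymm0..ymm11 as accumulators and uses ymm12..ymm15 for weights
+ * and broadcast inputs. Reading the prologue/epilogue: dst_flag bit 0 seeds the tile from
+ * dst (to accumulate across input-channel chunks), otherwise from bias when present, else
+ * from zero; dst_flag bit 1 marks the final chunk, where act_flag's low bits apply ReLU
+ * (vmaxps against zero) and ReLU6 (vminps against 6.0f, built by broadcasting 0x40C00000).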
+ */ +#include "nnacl_c/fp32/conv_1x1_avx_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" + +void Conv1x1SW3x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "vmovups 0x60(%7), %%ymm3\n" + "vmovups (%7, %6, 1), %%ymm4\n" + "vmovups 0x20(%7, %6, 1), %%ymm5\n" + "vmovups 0x40(%7, %6, 1), %%ymm6\n" + "vmovups 0x60(%7, %6, 1), %%ymm7\n" + "vmovups (%7, %6, 2), %%ymm8\n" + "vmovups 0x20(%7, %6, 2), %%ymm9\n" + "vmovups 0x40(%7, %6, 2), %%ymm10\n" + "vmovups 0x60(%7, %6, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vbroadcastss (%0, %4), %%ymm14\n" + "vbroadcastss (%0, %4, 2), %%ymm15\n" + "vmovups (%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 0x20(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 0x40(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 0x60(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vbroadcastss 4(%0, %4), %%ymm14\n" + "vbroadcastss 4(%0, %4, 2), %%ymm15\n" + "vmovups 128(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 160(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 192(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 224(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vbroadcastss 8(%0, %4), %%ymm14\n" + "vbroadcastss 8(%0, %4, 2), %%ymm15\n" + "vmovups 256(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 288(%1), %%ymm12\n" + "vfmadd231ps 
%%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 320(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 352(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vbroadcastss 12(%0, %4), %%ymm14\n" + "vbroadcastss 12(%0, %4, 2), %%ymm15\n" + "vmovups 384(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 416(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 448(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 480(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vbroadcastss 16(%0, %4), %%ymm14\n" + "vbroadcastss 16(%0, %4, 2), %%ymm15\n" + "vmovups 512(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 544(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 576(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 608(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vbroadcastss 20(%0, %4), %%ymm14\n" + "vbroadcastss 20(%0, %4, 2), %%ymm15\n" + "vmovups 640(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 672(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 704(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 736(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vbroadcastss 24(%0, %4), %%ymm14\n" + "vbroadcastss 24(%0, %4, 2), %%ymm15\n" + "vmovups 768(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 800(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 832(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 864(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vbroadcastss 28(%0, %4), %%ymm14\n" + 
"vbroadcastss 28(%0, %4, 2), %%ymm15\n" + "vmovups 896(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 928(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 960(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 992(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + "vmovups %%ymm3, 0x60(%7)\n" + "vmovups %%ymm4, (%7, %6, 1)\n" + "vmovups %%ymm5, 0x20(%7, %6, 1)\n" + "vmovups %%ymm6, 0x40(%7, %6, 1)\n" + "vmovups %%ymm7, 0x60(%7, %6, 1)\n" + "vmovups %%ymm8, (%7, %6, 2)\n" + "vmovups %%ymm9, 0x20(%7, %6, 2)\n" + "vmovups %%ymm10, 0x40(%7, %6, 2)\n" + "vmovups %%ymm11, 0x60(%7, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "vmovups 0x60(%7), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vmovups (%1), %%ymm4\n" + "vmovups 
0x20(%1), %%ymm5\n" + "vmovups 0x40(%1), %%ymm6\n" + "vmovups 0x60(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vmovups 128(%1), %%ymm4\n" + "vmovups 160(%1), %%ymm5\n" + "vmovups 192(%1), %%ymm6\n" + "vmovups 224(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vmovups 256(%1), %%ymm4\n" + "vmovups 288(%1), %%ymm5\n" + "vmovups 320(%1), %%ymm6\n" + "vmovups 352(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vmovups 384(%1), %%ymm4\n" + "vmovups 416(%1), %%ymm5\n" + "vmovups 448(%1), %%ymm6\n" + "vmovups 480(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vmovups 512(%1), %%ymm4\n" + "vmovups 544(%1), %%ymm5\n" + "vmovups 576(%1), %%ymm6\n" + "vmovups 608(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vmovups 640(%1), %%ymm4\n" + "vmovups 672(%1), %%ymm5\n" + "vmovups 704(%1), %%ymm6\n" + "vmovups 736(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vmovups 768(%1), %%ymm4\n" + "vmovups 800(%1), %%ymm5\n" + "vmovups 832(%1), %%ymm6\n" + "vmovups 864(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vmovups 896(%1), %%ymm4\n" + "vmovups 928(%1), %%ymm5\n" + "vmovups 960(%1), %%ymm6\n" + "vmovups 992(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + "vmovups %%ymm3, 0x60(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", 
"%ymm12", "%ymm13", + "%ymm14"); +} + +void Conv1x1SW4x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + asm volatile( + "movq %10, %%rax\n" // dst_flag + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups 0x20(%8), %%ymm1\n" + "vmovups 0x40(%8), %%ymm2\n" + "vmovups (%8, %7, 1), %%ymm3\n" + "vmovups 0x20(%8, %7, 1), %%ymm4\n" + "vmovups 0x40(%8, %7, 1), %%ymm5\n" + "vmovups (%8, %7, 2), %%ymm6\n" + "vmovups 0x20(%8, %7, 2), %%ymm7\n" + "vmovups 0x40(%8, %7, 2), %%ymm8\n" + "vmovups (%9), %%ymm9\n" + "vmovups 0x20(%9), %%ymm10\n" + "vmovups 0x40(%9), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vmovups (%1), %%ymm13\n" + "vmovups 0x20(%1), %%ymm14\n" + "vmovups 0x40(%1), %%ymm15\n" + "vbroadcastss (%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss (%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss (%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss (%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 96(%1), %%ymm13\n" + "vmovups 128(%1), %%ymm14\n" + "vmovups 160(%1), %%ymm15\n" + "vbroadcastss 4(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 4(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 4(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 4(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 192(%1), %%ymm13\n" + "vmovups 224(%1), %%ymm14\n" + "vmovups 256(%1), %%ymm15\n" + "vbroadcastss 8(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 8(%0, %4), %%ymm12\n" 
+ "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 8(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 8(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 288(%1), %%ymm13\n" + "vmovups 320(%1), %%ymm14\n" + "vmovups 352(%1), %%ymm15\n" + "vbroadcastss 12(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 12(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 12(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 12(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 384(%1), %%ymm13\n" + "vmovups 416(%1), %%ymm14\n" + "vmovups 448(%1), %%ymm15\n" + "vbroadcastss 16(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 16(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 16(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 16(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 480(%1), %%ymm13\n" + "vmovups 512(%1), %%ymm14\n" + "vmovups 544(%1), %%ymm15\n" + "vbroadcastss 20(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 20(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 20(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 20(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 576(%1), %%ymm13\n" + "vmovups 608(%1), %%ymm14\n" + "vmovups 640(%1), %%ymm15\n" + "vbroadcastss 24(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 24(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 24(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 24(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 
672(%1), %%ymm13\n" + "vmovups 704(%1), %%ymm14\n" + "vmovups 736(%1), %%ymm15\n" + "vbroadcastss 28(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 28(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 28(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 28(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "addq $768, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %10, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %6, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%8)\n" // dst_0 + "vmovups %%ymm1, 0x20(%8)\n" + "vmovups %%ymm2, 0x40(%8)\n" + "vmovups %%ymm3, (%8, %7, 1)\n" + "vmovups %%ymm4, 0x20(%8, %7, 1)\n" + "vmovups %%ymm5, 0x40(%8, %7, 1)\n" + "vmovups %%ymm6, (%8, %7, 2)\n" + "vmovups %%ymm7, 0x20(%8, %7, 2)\n" + "vmovups %%ymm8, 0x40(%8, %7, 2)\n" + "vmovups %%ymm9, (%9)\n" + "vmovups %%ymm10, 0x20(%9)\n" + "vmovups %%ymm11, 0x40(%9)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vmovups (%1), %%ymm4\n" + "vmovups 0x20(%1), 
%%ymm5\n" + "vmovups 0x40(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vmovups 96(%1), %%ymm4\n" + "vmovups 128(%1), %%ymm5\n" + "vmovups 160(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vmovups 192(%1), %%ymm4\n" + "vmovups 224(%1), %%ymm5\n" + "vmovups 256(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vmovups 288(%1), %%ymm4\n" + "vmovups 320(%1), %%ymm5\n" + "vmovups 352(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vmovups 384(%1), %%ymm4\n" + "vmovups 416(%1), %%ymm5\n" + "vmovups 448(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vmovups 480(%1), %%ymm4\n" + "vmovups 512(%1), %%ymm5\n" + "vmovups 544(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vmovups 576(%1), %%ymm4\n" + "vmovups 608(%1), %%ymm5\n" + "vmovups 640(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vmovups 672(%1), %%ymm4\n" + "vmovups 704(%1), %%ymm5\n" + "vmovups 736(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "addq $768, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6", "%ymm12", "%ymm13", "%ymm14"); +} + +void Conv1x1SW6x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + asm volatile( + "movq %10, %%rax\n" // dst_flag + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups 0x20(%8), %%ymm1\n" + "vmovups (%8, %7, 1), %%ymm2\n" + "vmovups 0x20(%8, %7, 1), %%ymm3\n" + "vmovups (%8, %7, 2), %%ymm4\n" + "vmovups 0x20(%8, %7, 2), %%ymm5\n" + "vmovups (%9), %%ymm6\n" + "vmovups 0x20(%9), %%ymm7\n" + 
"vmovups (%9, %7, 1), %%ymm8\n" + "vmovups 0x20(%9, %7, 1), %%ymm9\n" + "vmovups (%9, %7, 2), %%ymm10\n" + "vmovups 0x20(%9, %7, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm2\n" + "vmovups 0x20(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups 0x20(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "movq %0, %%rax\n" + "addq %5, %%rax\n" + + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + "vbroadcastss (%0), %%ymm14\n" + "vbroadcastss (%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss (%0, %4, 2), %%ymm14\n" + "vbroadcastss (%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 64(%1), %%ymm12\n" + "vmovups 96(%1), %%ymm13\n" + "vbroadcastss 4(%0), %%ymm14\n" + "vbroadcastss 4(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 4(%0, %4, 2), %%ymm14\n" + "vbroadcastss 4(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 4(%%rax, %4), %%ymm14\n" + "vbroadcastss 4(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 128(%1), %%ymm12\n" + "vmovups 160(%1), %%ymm13\n" + "vbroadcastss 8(%0), %%ymm14\n" + "vbroadcastss 8(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 8(%0, %4, 2), %%ymm14\n" + "vbroadcastss 8(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 8(%%rax, %4), %%ymm14\n" + "vbroadcastss 8(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps 
%%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 192(%1), %%ymm12\n" + "vmovups 224(%1), %%ymm13\n" + "vbroadcastss 12(%0), %%ymm14\n" + "vbroadcastss 12(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 12(%0, %4, 2), %%ymm14\n" + "vbroadcastss 12(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 12(%%rax, %4), %%ymm14\n" + "vbroadcastss 12(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 256(%1), %%ymm12\n" + "vmovups 288(%1), %%ymm13\n" + "vbroadcastss 16(%0), %%ymm14\n" + "vbroadcastss 16(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 16(%0, %4, 2), %%ymm14\n" + "vbroadcastss 16(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 16(%%rax, %4), %%ymm14\n" + "vbroadcastss 16(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 320(%1), %%ymm12\n" + "vmovups 352(%1), %%ymm13\n" + "vbroadcastss 20(%0), %%ymm14\n" + "vbroadcastss 20(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 20(%0, %4, 2), %%ymm14\n" + "vbroadcastss 20(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 20(%%rax, %4), %%ymm14\n" + "vbroadcastss 20(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 384(%1), %%ymm12\n" + "vmovups 416(%1), %%ymm13\n" + "vbroadcastss 24(%0), %%ymm14\n" + "vbroadcastss 24(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 24(%0, %4, 2), %%ymm14\n" + "vbroadcastss 24(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 24(%%rax, %4), %%ymm14\n" + "vbroadcastss 24(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 448(%1), %%ymm12\n" + "vmovups 480(%1), %%ymm13\n" + "vbroadcastss 28(%0), %%ymm14\n" + "vbroadcastss 28(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + 
"vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 28(%0, %4, 2), %%ymm14\n" + "vbroadcastss 28(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 28(%%rax, %4), %%ymm14\n" + "vbroadcastss 28(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "addq $512, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %10, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %6, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%8)\n" // dst_0 + "vmovups %%ymm1, 0x20(%8)\n" + "vmovups %%ymm2, (%8, %7, 1)\n" + "vmovups %%ymm3, 0x20(%8, %7, 1)\n" + "vmovups %%ymm4, (%8, %7, 2)\n" + "vmovups %%ymm5, 0x20(%8, %7, 2)\n" + "vmovups %%ymm6, (%9)\n" // dst+3 + "vmovups %%ymm7, 0x20(%9)\n" + "vmovups %%ymm8, (%9, %7, 1)\n" + "vmovups %%ymm9, 0x20(%9, %7, 1)\n" + "vmovups %%ymm10, (%9, %7, 2)\n" + "vmovups %%ymm11, 0x20(%9, %7, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm12\n" + "vmovups (%1), %%ymm13\n" + "vmovups 0x20(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 4(%0), %%ymm12\n" + "vmovups 64(%1), %%ymm13\n" + "vmovups 96(%1), %%ymm14\n" + 
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 8(%0), %%ymm12\n" + "vmovups 128(%1), %%ymm13\n" + "vmovups 160(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 12(%0), %%ymm12\n" + "vmovups 192(%1), %%ymm13\n" + "vmovups 224(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 16(%0), %%ymm12\n" + "vmovups 256(%1), %%ymm13\n" + "vmovups 288(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 20(%0), %%ymm12\n" + "vmovups 320(%1), %%ymm13\n" + "vmovups 352(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 24(%0), %%ymm12\n" + "vmovups 384(%1), %%ymm13\n" + "vmovups 416(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 28(%0), %%ymm12\n" + "vmovups 448(%1), %%ymm13\n" + "vmovups 480(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "addq $512, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm13", "%ymm14"); +} + +void Conv1x1SW12x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + ic_align <<= 3; + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + float *dst_5 = dst + 5 * oc_align / sizeof(float); + float *dst_9 = dst + 9 * oc_align / sizeof(float); + asm volatile( + "movq %12, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups (%8, %7), %%ymm1\n" + "vmovups (%8, %7, 2), %%ymm2\n" + "vmovups (%9), %%ymm3\n" // dst_3 + "vmovups (%8, %7, 4), %%ymm4\n" + "vmovups (%10), %%ymm5\n" // dst_5 + "vmovups (%10, %7, 1), %%ymm6\n" + "vmovups (%10, %7, 2), %%ymm7\n" + "vmovups (%8, %7, 8), %%ymm8\n" + "vmovups (%11), %%ymm9\n" // dst_9 + "vmovups (%11, %7, 1), %%ymm10\n" + "vmovups (%11, %7, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups (%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups (%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups (%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups (%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + 
"vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "movq %0, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $32, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 2b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(dst_flag) // 12 + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x2, %%eax\n" + "je 0f\n" + "movq %0, %%rax\n" + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovups %%ymm8, (%2, %1, 8)\n" + "vmovups %%ymm9, (%5)\n" // dst_9 + "vmovups %%ymm10, (%5, %1, 1)\n" + "vmovups %%ymm11, (%5, %1, 2)\n" + : + : "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "a"(dst_flag) // 6 + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", 
"%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void Conv1x1SW1x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm12\n" + "vmovups (%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 4(%0), %%ymm12\n" + "vmovups 32(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 8(%0), %%ymm12\n" + "vmovups 64(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 12(%0), %%ymm12\n" + "vmovups 96(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 16(%0), %%ymm12\n" + "vmovups 128(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 20(%0), %%ymm12\n" + "vmovups 160(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 24(%0), %%ymm12\n" + "vmovups 192(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 28(%0), %%ymm12\n" + "vmovups 224(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "addq $256, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm12", "%ymm13"); +} + +// sliding window to compate 1x1 conv in x86 +void Conv1x1SWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param) { + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int ohw = output_h * output_w; + int ohw_step = UP_DIV(ohw, conv_param->thread_num_); + int ohw_start = ohw_step * task_id; + int ohw_end = MSMIN(ohw_start + ohw_step, ohw); + if (ohw_start >= ohw_end) { + return; + } + int act_type = C0NUM; + int oc_tile_ = C8NUM; // oc in algin to C8NUM in x86_64_avx + if (conv_param->act_type_ == ActType_Relu6) { + act_type += C1NUM; + } + if (conv_param->act_type_ == ActType_Relu6 || conv_param->act_type_ == ActType_Relu) { + act_type += C2NUM; + } + int pad_d = conv_param->pad_d_; + int pad_l = conv_param->pad_l_; + int pad_r = conv_param->pad_r_; + int pad_u = conv_param->pad_u_; + int oc_align = sw_param->block_channel_; + int oc_align_float = oc_align * sizeof(float); + int ic_align = sw_param->ic_align_; + int in_sw_step = sw_param->in_sw_step_; + int in_sw_step_float = sw_param->in_sw_step_ * sizeof(float); + int kernel_step = sw_param->kernel_step_; + int oc_num = sw_param->c_block_; + int in_step = sw_param->in_step_; + int out_step = sw_param->out_step_; + const int ow_block_num[4] = {12, 6, 4, 3}; + const Conv1x1SWAVXKernel kernel[4][2] = 
{{Conv1x1SW1x8AVXKernel, Conv1x1SW12x8AVXKernel}, + {Conv1x1SW1x16AVXKernel, Conv1x1SW6x16AVXKernel}, + {Conv1x1SW1x24AVXKernel, Conv1x1SW4x24AVXKernel}, + {Conv1x1SW1x32AVXKernel, Conv1x1SW3x32AVXKernel}}; + for (int b = 0; b < conv_param->output_batch_; b++) { + int ic_block = 128; + int dst_flag = 0; + for (int ic = 0; ic < ic_align; ic += ic_block) { + if (ic_align - ic <= ic_block) { + ic_block = ic_align - ic; + dst_flag = C3NUM - (ic == 0); + } else { + dst_flag = 1 - (ic == 0); + } + if (pad_d == 0 && pad_l == 0 && pad_r == 0 && pad_u == 0) { + const float *bias = bias_data; + int oc_block = 0; + for (int oc = 0; oc < oc_num; oc += oc_block) { + oc_block = MSMIN(C4NUM, oc_num - oc); // 4 3 2 1 + const float *weight = packed_weight + oc * kernel_step + ic * C8NUM * oc_block; + if (bias != NULL) { + bias = bias_data + oc * oc_tile_; + } + const float *src_w = input_data + ic + ohw_start * in_sw_step; + float *dst_oc = output_data + oc * oc_tile_; + int hw_block = ow_block_num[oc_block - 1]; + for (int hw = ohw_start; hw < ohw_end; hw += hw_block) { + if (hw_block > ohw_end - hw) { // the remaining ow is less than a full block, so process one ow at a time + hw_block = 1; + } + float *dst_w = dst_oc + hw * oc_align; + kernel[oc_block - 1][hw_block / ow_block_num[oc_block - 1]](dst_w, src_w, weight, bias, act_type, hw_block, + oc_block, oc_align_float, ic_block >> C3NUM, + in_sw_step_float, dst_flag); + src_w += hw_block * in_sw_step; + } + } + } + } + input_data += in_step; + output_data += out_step; + } // batch loop +} + +#ifdef ENABLE_DEBUG +void Conv1x1SWOWxOCAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + oc_align /= sizeof(float); + in_sw_step /= sizeof(float); + ic_align <<= C3NUM; + __m256 dst_data[12]; + const float *src_sw[12]; + __m256 weight_data[4]; + for (int i = 0; i < C4NUM; ++i) { + weight_data[i] = _mm256_set1_ps(0.0f); + } + for (int i = 0; i < ow_block; ++i) { + if (dst_flag & 0x01) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(dst + i * oc_align + j * C8NUM); + } + } else { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + } + src_sw[i] = src + i * in_sw_step; + } + const float *weight_kernel = weight; + for (int ic = 0; ic < ic_align; ++ic) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] += src_sw[i][ic] * weight_data[j]; + } + } + weight_kernel += C8NUM * oc_block; + } // ic loop + // add bias and relu + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (dst_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + } + _mm256_storeu_ps(dst + i * oc_align + j * C8NUM, dst_data[i * oc_block + j]); + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h new file mode 100644 index 
0000000000000000000000000000000000000000..6a948da3fd107842ce3158b411681b9ba9cff285 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*Conv1x1SWAVXKernel)(float *dst, const float *src, const float *weight, const float *bias, + size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, + size_t in_sw_step, size_t dst_flag); + +void Conv1x1SWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); + +#ifdef ENABLE_DEBUG +void Conv1x1SWOWxOCAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag); +#endif +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..51ce48be2ecd9b420e5dae74bc5f4361fecf797c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h @@ -0,0 +1,21 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ + +#include "nnacl_c/fp32/conv_1x1_avx_fp32.h" + +#endif // MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..47302410860a3eaa6e26547f5839c9f48a36e93b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c @@ -0,0 +1,435 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_common_fp32.h" +#include <string.h> +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#endif +#include "nnacl_c/fp32/matmul_fp32.h" +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int out_w = conv_param->output_w_; + if (dilation_h == 0 || dilation_w == 0 || out_w == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) { + continue; + } + int input_stride = (input_h * in_w + input_w) * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (kw_e <= kw_s) { + continue; + } + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} + +// fp32 
conv common +void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + Row2ColMajorFuncPtr Row2ColMajor = NULL; + int output_hw = conv_param->output_h_ * conv_param->output_w_; +#ifdef ENABLE_AVX + Row2ColMajor = RowMajor2Col6Major; + const int cal_num = C6NUM; +#elif defined(ENABLE_SSE) + Row2ColMajor = RowMajor2Col4Major; + const int cal_num = C4NUM; +#elif defined(ENABLE_ARM64) + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + int cal_num = 0; + if (output_hw <= C4NUM) { + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + cal_num = C4NUM; + } else if (output_hw <= C8NUM) { + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + cal_num = C8NUM; + } else { + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + cal_num = C12NUM; + } +#elif defined(ENABLE_ARM32) + Row2ColMajor = RowMajor2Col12Major; + const int cal_num = C12NUM; +#else + Row2ColMajor = RowMajor2Col12Major; + const int cal_num = C12NUM; +#endif + + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } + int out_stride = conv_param->output_channel_ * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int out_channel = conv_param->output_channel_; + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel * output_hw + start_hw * out_channel; + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +// x86 func param types are different +#if ENABLE_AVX + MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc); +#elif ENABLE_SSE + MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc); +#elif ENABLE_ARM32 + MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, out_channel, OutType_Nhwc); +#elif ENABLE_ARM64 + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#else + MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#endif + } + } +} + +// fp32 conv common +void ConvFp32CutByBatch(const float *input_data, float *packed_input, const float 
*packed_weight, + const float *bias_data, float *col_major_input, float *output_data, int task_id, + const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + Row2ColMajorFuncPtr Row2ColMajor = NULL; +#ifdef ENABLE_AVX + const int cal_num = C6NUM; + Row2ColMajor = RowMajor2Col6Major; +#elif defined(ENABLE_SSE) + const int cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; +#elif defined(ENABLE_ARM64) + int cal_num = 0; + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + if (output_hw <= C4NUM) { + cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + } else if (output_hw <= C8NUM) { + cal_num = C8NUM; + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + } else { + cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + } +#elif defined(ENABLE_ARM32) + const int cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; +#else + const int cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; +#endif + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + int out_stride = conv_param->output_channel_ * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = start_batch; b < end_batch; b++) { + int out_channel = conv_param->output_channel_; + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel * output_hw; + for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +// x86 func param types are different +#if ENABLE_AVX + MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc); +#elif ENABLE_SSE + MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc); +#elif ENABLE_ARM32 + MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, out_channel, OutType_Nhwc); +#elif ENABLE_ARM64 + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#else + MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#endif + } + } +} + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { + if 
(conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int input_hw = conv_param->input_h_ * conv_param->input_w_; + int in_channel = conv_param->input_channel_; + Row2ColMajorFuncPtr Row2ColMajor = NULL; + int cal_num = 0; + int out_tile = 0; +#ifdef ENABLE_AVX + cal_num = C6NUM; + out_tile = C8NUM; + Row2ColMajor = RowMajor2Col6Major; + int align_channel = UP_DIV(out_channel, C16NUM) * C16NUM; +#else + out_tile = C4NUM; + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + if (output_hw <= C4NUM) { + cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + } else if (output_hw <= C8NUM) { + cal_num = C8NUM; + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + } else { + cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + } +#endif + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } +#ifdef ENABLE_AVX + int act_type = 0; + if (conv_param->act_type_ == ActType_Relu6) { + act_type += 1; + } + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { + act_type += 2; + } + int out_stride = out_tile * cal_num; + int out_block_stride = output_hw * C8NUM; +#else + int out_stride = MSMIN(out_channel, out_tile) * cal_num; +#endif + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * in_channel * input_hw; +#ifdef ENABLE_AVX + int out_offset = b * align_channel * output_hw + start_hw * out_tile; +#else + int out_offset = b * out_channel * output_hw + start_hw * MSMIN(out_channel, out_tile); +#endif + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +#ifdef ENABLE_AVX + for (int oc = 0; oc < out_channel; oc += C16NUM) { + CommonConv6x16Kernel(gemm_output + oc * output_hw, col_major_input, packed_weight + oc * deep, bias_data + oc, + deep, out_block_stride, act_type, real_cal_row); + } +#else + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, output_hw, OutType_NC4HW4); +#endif + } + } +} +#endif + +#ifdef ENABLE_AVX +void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t depth, + const size_t out_step, const size_t act_flag, const size_t real_cal_row) { +#define Store1 \ + _mm256_storeu_ps(dst, out[0]); \ + _mm256_storeu_ps(dst + out_step, out[1]); +#define Store2 \ + Store1 _mm256_storeu_ps(dst + C8NUM, out[2]); \ + _mm256_storeu_ps(dst + out_step + C8NUM, out[3]); +#define Store3 \ + Store2 _mm256_storeu_ps(dst + C16NUM, out[4]); \ + _mm256_storeu_ps(dst + out_step + C16NUM, out[5]); +#define Store4 \ + 
Store3 _mm256_storeu_ps(dst + C24NUM, out[6]); \ + _mm256_storeu_ps(dst + out_step + C24NUM, out[7]); +#define Store5 \ + Store4 _mm256_storeu_ps(dst + C32NUM, out[8]); \ + _mm256_storeu_ps(dst + out_step + C32NUM, out[9]); +#define Store6 \ + Store5 _mm256_storeu_ps(dst + C40NUM, out[10]); \ + _mm256_storeu_ps(dst + out_step + C40NUM, out[11]); + + __m256 out[12]; + if (bias != NULL) { + out[0] = _mm256_loadu_ps(bias); + out[1] = _mm256_loadu_ps(bias + C8NUM); + } else { + out[0] = _mm256_set1_ps(0.0f); + out[1] = _mm256_set1_ps(0.0f); + } + out[2] = out[0]; + out[3] = out[1]; + out[4] = out[0]; + out[5] = out[1]; + out[6] = out[0]; + out[7] = out[1]; + out[8] = out[0]; + out[9] = out[1]; + out[10] = out[0]; + out[11] = out[1]; + for (int d = 0; d < depth; ++d) { + __m256 w1 = _mm256_loadu_ps(weight); + __m256 w2 = _mm256_loadu_ps(weight + C8NUM); + __m256 s1 = _mm256_set1_ps(*src); + __m256 s2 = _mm256_set1_ps(*(src + 1)); + out[0] = _mm256_fmadd_ps(s1, w1, out[0]); + out[1] = _mm256_fmadd_ps(s1, w2, out[1]); + out[2] = _mm256_fmadd_ps(s2, w1, out[2]); + out[3] = _mm256_fmadd_ps(s2, w2, out[3]); + s1 = _mm256_set1_ps(*(src + 2)); + s2 = _mm256_set1_ps(*(src + 3)); + out[4] = _mm256_fmadd_ps(s1, w1, out[4]); + out[5] = _mm256_fmadd_ps(s1, w2, out[5]); + out[6] = _mm256_fmadd_ps(s2, w1, out[6]); + out[7] = _mm256_fmadd_ps(s2, w2, out[7]); + s1 = _mm256_set1_ps(*(src + 4)); + s2 = _mm256_set1_ps(*(src + 5)); + out[8] = _mm256_fmadd_ps(s1, w1, out[8]); + out[9] = _mm256_fmadd_ps(s1, w2, out[9]); + out[10] = _mm256_fmadd_ps(s2, w1, out[10]); + out[11] = _mm256_fmadd_ps(s2, w2, out[11]); + weight += C16NUM; + src += C6NUM; + } + __m256 six = _mm256_set1_ps(6.0f); + __m256 zero = _mm256_set1_ps(0.0f); + if (0x1 & act_flag) { // relu6 + out[0] = _mm256_min_ps(out[0], six); + out[1] = _mm256_min_ps(out[1], six); + out[2] = _mm256_min_ps(out[2], six); + out[3] = _mm256_min_ps(out[3], six); + out[4] = _mm256_min_ps(out[4], six); + out[5] = _mm256_min_ps(out[5], six); + out[6] = _mm256_min_ps(out[6], six); + out[7] = _mm256_min_ps(out[7], six); + out[8] = _mm256_min_ps(out[8], six); + out[9] = _mm256_min_ps(out[9], six); + out[10] = _mm256_min_ps(out[10], six); + out[11] = _mm256_min_ps(out[11], six); + } + if (0x2 & act_flag) { // relu + out[0] = _mm256_max_ps(out[0], zero); + out[1] = _mm256_max_ps(out[1], zero); + out[2] = _mm256_max_ps(out[2], zero); + out[3] = _mm256_max_ps(out[3], zero); + out[4] = _mm256_max_ps(out[4], zero); + out[5] = _mm256_max_ps(out[5], zero); + out[6] = _mm256_max_ps(out[6], zero); + out[7] = _mm256_max_ps(out[7], zero); + out[8] = _mm256_max_ps(out[8], zero); + out[9] = _mm256_max_ps(out[9], zero); + out[10] = _mm256_max_ps(out[10], zero); + out[11] = _mm256_max_ps(out[11], zero); + } + if (real_cal_row == C6NUM) { + Store6 + } else if (real_cal_row == C5NUM) { + Store5 + } else if (real_cal_row == C4NUM) { + Store4 + } else if (real_cal_row == C3NUM) { + Store3 + } else if (real_cal_row == C2NUM) { + Store2 + } else if (real_cal_row == C1NUM) { + Store1 + } +} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..35eb7ccd9e6fa69735d28a8c05bdeb0c9110aa12 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_COMMON_H_ +#define MINDSPORE_NNACL_FP32_CONV_COMMON_H_ + +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/conv_sw_avx_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, int row, int col); +#ifdef ENABLE_ARM64 +typedef void (*MatmulFloatOptFuncPtr)(const float *a, const float *b, float *c, const float *bias, int act_type, + int depth, int row, int col, size_t stride, size_t write_mode); +#endif + +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index); + +// fp32 convolution common (im2col+gemm) +void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); + +// fp32 convolution common (im2col+gemm) +void ConvFp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *col_major_input, float *output_data, int task_id, + const ConvParameter *conv_param); + +// common convolution with NC4HW4 output; if out_channel is not a multiple of 4, only the real channels are written, with no zero padding. +void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); + +#ifdef ENABLE_AVX +void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t depth, + size_t out_step, size_t act_flag, size_t real_cal_row); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_COMMON_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..55e11032bbebcdb050e9676a45d92cb649ddb22a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/conv_depthwise_avx_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/activation_fp32.h" + +int ConvDwAVX(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + + int32_t *num_pixels = conv_dw_calc_param->num_pixels_; + int32_t *out_w_start = conv_dw_calc_param->out_w_start_; + int first_calc_kw = conv_dw_calc_param->first_calc_kw_; + + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + bool first_calc_flag = true; + if (first_calc_kw == -1) { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + first_calc_flag = false; + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + if (first_calc_flag) { + int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw; + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + ConvDwAVXFp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_, + conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data); + } + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + if (first_calc_flag && (kw == first_calc_kw)) { + weight_kh += conv_param->output_channel_; + first_calc_flag = false; + continue; + } + int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_; + + ConvDwAVXFp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false, + bias_data); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return 
NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..5ab723165059e394b710e60e9b8296c2a51c09df --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ +#define MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/base/conv_common_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwAVX(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param_); + +void ConvDwAVXFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step, bool first_calc_flag, const float *bias); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c8b7de864790cad014247184ad365538f5477491 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c @@ -0,0 +1,2074 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/activation_fp32.h" + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, + int output_channel, int input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *dw_weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + float *dst_w = dst_data + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + + ConvDwFp32Row(dst_w, src_kw, dw_weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); + dw_weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} + +#ifdef ENABLE_AVX512 +int ConvDwAVX512(float *output_data, const float *input_data, 
const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + int32_t *num_pixels = conv_dw_calc_param->num_pixels_; + int32_t *out_w_start = conv_dw_calc_param->out_w_start_; + int first_calc_kw = conv_dw_calc_param->first_calc_kw_; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + bool first_calc_flag = true; + if (first_calc_kw == -1) { + first_calc_flag = false; + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + + if (first_calc_flag) { + int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + + ConvDwAVX512Fp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_, + conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data); + } + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + if (first_calc_flag && (kw == first_calc_kw)) { + first_calc_flag = false; + weight_kh += conv_param->output_channel_; + continue; + } + + float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_; + int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + + ConvDwAVX512Fp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false, + bias_data); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } else if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} +#endif + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + if (block == 0) { + return; + } + int left = 0; + int right = conv_param->output_w_; + int top = 0; + int bottom = conv_param->output_h_; + + while (left * conv_param->stride_w_ < 
conv_param->pad_l_) { + left++; + } + while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ > + conv_param->input_w_ && + right > left) { + right--; + } + while (top * conv_param->stride_h_ < conv_param->pad_u_) { + top++; + } + while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ > + conv_param->input_h_ && + bottom > top) { + bottom--; + } + sliding->left_ = left; + sliding->right_ = right; + sliding->top_ = top; + sliding->bottom_ = bottom; + sliding->c_block_ = UP_DIV(conv_param->output_channel_, block); + sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block; + sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_; + if (conv_param->out_format_ == Format_NC4HW4) { + // write to nc8hw8 + sliding->out_h_step_ = conv_param->output_w_ * block; + sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_; + sliding->out_w_step_ = block; + sliding->out_block_step_ = sliding->out_c_step_; + } else { + // write to nhwc + sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_; + sliding->out_c_step_ = block; + sliding->out_w_step_ = sliding->block_channel_; + sliding->out_block_step_ = sliding->out_w_step_; + } +} + +void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block) { + InitSlidingParam(sliding, conv_param, weight_block); + AppendSlidingParamConv(sliding, conv_param, input_block, weight_block); +} + +void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block) { + if (input_block == 0) { // is not aligned + sliding->ic_align_ = conv_param->input_channel_; + } else { // 1x1 input is aligned to input_block + sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block; + } + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_; // kernel W + sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block; +} + +void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + InitSlidingParam(sliding, conv_param, block); + AppendSlidingParamConvDw(sliding, conv_param, block); +} + +void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W + 
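  // Depthwise weights hold a single block-channel plane per kernel tap, so the
+  // weight stride per channel block below is just kernel_w * kernel_h * block floats.
+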
sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block; +} + +/*conv depthwise fp32 begin*/ +void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) { + const float *src_kh = src; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst[c] = 0; + } + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sliding->in_h_step_; + + float *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sliding->block_channel_; + + const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; +#ifdef ENABLE_AVX + ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam)); + if (param == NULL) { + return; + } + param->dst = dst_kernel; + param->src = src_kernel; + param->weight = weight_kernel; + param->bias = bias; + param->height = end_kh - start_kh; + param->width = end_kw - start_kw; + param->in_kh_step = sliding->in_kh_step_ * sizeof(float); + param->in_kw_step = sliding->in_kw_step_ * sizeof(float); + param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float); + param->relu = relu; + param->relu6 = relu6; + ConvDwFp32Border(param); + free(param); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), + conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); +#else + ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6); +#endif + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += 
sliding->out_h_step_;
+  }  // height loop
+}
+
+#ifndef ENABLE_ARM64
+void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
+                  int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
+                  int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
+  float *dst_h = dst;
+  const float *src_h = src;
+  for (int oh = 0; oh < height; oh++) {
+    float *dst_w = dst_h;
+    const float *src_w = src_h;
+    for (int ow = 0; ow < width; ow++) {
+      const float *src_kh = src_w;
+      const float *weight_kh = weight;
+      for (int c = 0; c < C4NUM; c++) {
+        dst_w[c] = 0;
+      }
+      for (int kh = 0; kh < kernel_h; kh++) {
+        const float *src_kw = src_kh;
+        const float *weight_kw = weight_kh;
+        for (int kw = 0; kw < kernel_w; kw++) {
+          for (int c = 0; c < C4NUM; c++) {
+            dst_w[c] += src_kw[c] * weight_kw[c];
+          }
+          src_kw += in_kw_step;
+          weight_kw += C4NUM;
+        }  // kernel_w loop
+        src_kh += in_kh_step;
+        weight_kh += kernel_w * C4NUM;
+      }  // kernel_h loop
+      // add bias and relu
+      for (int c = 0; c < C4NUM; c++) {
+        dst_w[c] += bias[c];
+        dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]);
+        dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]);
+      }
+      dst_w += block_channel;
+      src_w += in_sw_step;
+    }  // dst_width loop
+    dst_h += out_h_step;
+    src_h += in_sh_step;
+  }  // dst_height loop
+}
+#endif
+
+// conv depthwise fp32: sliding window
+void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
+                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
+  bool relu = conv_param->act_type_ == ActType_Relu;
+  bool relu6 = conv_param->act_type_ == ActType_Relu6;
+  if (conv_param->thread_num_ == 0) {
+    return;
+  }
+  const float *src = input_data;
+  float *dst = output_data;
+  for (int b = 0; b < conv_param->output_batch_; b++) {
+    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
+      const float *src_data = src + oc * C4NUM;
+      float *dst_data = dst + oc * C4NUM;
+      const float *weight = weight_data + oc * sliding->kernel_step_;
+      const float *bias = bias_data + oc * C4NUM;
+      ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding);
+      ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_,
+                   conv_param, sliding);
+      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
+                   sliding);
+      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
+                   conv_param->output_w_, conv_param, sliding);
+
+      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
+        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
+        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
+        const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
+        float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+        ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
+                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
+                         sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
+                         sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ 
* sizeof(float), + sliding->in_kw_step_ * sizeof(float), relu, relu6); +#else + ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu, + relu6); +#endif + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*conv depthwise fp32 end*/ + +/*conv depthwise 3x3 fp32 begin*/ +bool CheckConvDwUse3X3(const ConvParameter *conv_param) { + bool use_3x3 = + conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && + (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && + (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ && + (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && + conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1; + if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) { + return false; + } + const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_; + const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_; + return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) && + in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_); +} + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + v0 = MS_MOVQ_F32(0.0f); + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v1 = MS_LDQ_F32(src + ic); + v2 = MS_LDQ_F32(src + channel + ic); + v3 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d1 = src[i + ic]; + float d2 = src[i + ic + channel]; + float d3 = src[i + ic + 2 * channel]; + remain_line[i] = 0.0f - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; + } + } +} + +static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + v3 = MS_LDQ_F32(src + 3 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + float d3 = src[i + ic + 3 * channel]; + remain_line[i] = d0 - 
d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; + } + } +} + +static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + v3 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = 0.0f - d1; + } + } +} + +static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_STQ_F32(line + lw * ic, v0); + MS_STQ_F32(line + lw * ic + 4, v1); + MS_STQ_F32(line + lw * ic + 8, b2); + memset(line + lw * ic + 12, 0, 16); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 4] = d1; + remain_line[i + 8] = 0.0f - d1; + } + } +} + +static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeft(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + 
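  // line0/line1/line2 hold the transformed input rows above / at / below the current output row
+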
ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3Bottom(float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float)); +} + +#ifndef ENABLE_ARM64 +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6) { + int channel = ori_channel; + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + for (; channel > 0; channel -= 4) { + MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data); + bias_data += 4; + MS_FLOAT32X4 g00 = MS_LDQ_F32(weight); + MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4); + MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8); + MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12); + MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16); + MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20); + MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24); + MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28); + MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32); + MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36); + MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40); + MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44); + weight += 48; + float *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + 
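      // Winograd F(2,3)-style output transform: res0 = m0 + m1 + m2, res1 = m1 - m2 + m3
+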
MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2)); + res0 = MS_ADDQ_F32(res0, bias); + res1 = MS_ADDQ_F32(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + MS_STQ_F32(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = MS_F32X4_GETI(res0, i); + cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i); + } + } + cur_dst += 2 * ori_channel; + } + if (ow < width) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + res0 = MS_ADDQ_F32(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = MS_F32X4_GETI(res0, i); + } + } + } + dst += 4; + } +} +#endif + +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float *line0 = buffer; + float *line1 = buffer + units * c4 * C4NUM; + float *line2 = buffer + units * c4 * C8NUM; + float *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, 
conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } +} +#endif + +/*conv depthwise indirect buffer fp32 begin*/ +bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) { + bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) || + (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5); + return use_indirect; +} + +void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, + int step_h, int step_w) { +#ifdef ENABLE_AVX + int div = C8NUM; +#else + int div = C4NUM; +#endif + + int ic_div = UP_DIV(conv_param->input_channel_, div) * div; + for (int b = 0; b < conv_param->output_batch_; b++) { + float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h; + float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div; + for (int oh = 0; oh < conv_param->output_h_; oh++) { + for (int kh = 0; kh < conv_param->kernel_h_; kh++) { + int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_; + if (ih < conv_param->input_h_ && ih >= 0) { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_; + int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh; + if (iw < conv_param->input_w_ && iw >= 0) { + indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div; + } else { + indirect[index] = zero_ptr; + } + } + } + } else { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh; + indirect[index] = zero_ptr; + } + } + } + } + } + } +} + +#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX) +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + do { + float **in = input; + size_t c = (size_t)channels; + const float *w = weights; + float *out = output; + memcpy(out, bias, channels * (int)sizeof(float)); + for (; c >= C4NUM; c -= C4NUM) { + for (int i = 0; i < C4NUM; i++) { + for (int k = 0; k < kernel; k++) { + out[i] += in[k][i] * w[i + k * C4NUM]; + } + } + w += kernel * C4NUM; + out += C4NUM; + for (int k = 0; k < kernel; k++) { + in[k] += C4NUM; + } + } + for (int i = 0; i < c; i++) { + for (int k = 0; k < kernel; k++) { + out[i] += in[k][i] * w[i + k * C4NUM]; + } + } + if (relu) { + Fp32Relu(output, channels, output); + } + if (relu6) { + Fp32Relu6(output, channels, output); + } + output += channels; + input = input + input_stride; + } while (--output_width != 0); +} +#endif + +#ifdef ENABLE_ARM64 +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + if (kernel == 9) { + ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, + relu6); + } else if (kernel == 25) { + ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, + relu6); + } +} +#endif + +#ifdef ENABLE_AVX +void 
ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
+                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
+  if (kernel == 9) {
+    ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
+  } else if (kernel == 25) {
+    ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
+  }
+}
+#endif
+
+void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
+                       float *zero_ptr, const ConvParameter *conv_param, int task_id) {
+  if (conv_param->thread_num_ == 0) {
+    return;
+  }
+  int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_;
+  int step_h =
+    (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_;
+  int input_stride = conv_param->kernel_h_ * step_w;
+
+  bool relu = conv_param->act_type_ == ActType_Relu;
+  bool relu6 = conv_param->act_type_ == ActType_Relu6;
+
+  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
+  int h_start = h_step * task_id;
+  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
+
+  for (int b = 0; b < conv_param->output_batch_; b++) {
+    float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h;
+    float *output_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
+    for (int oh = h_start; oh < h_end; oh++) {
+      float **indirect = indirect_b + oh * step_h;
+      float *output_h = output_b + oh * conv_param->output_w_ * conv_param->output_channel_;
+      if (conv_param->kernel_w_ == 3) {
+        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
+                              conv_param->output_w_, input_stride, relu, relu6, 9);
+      } else if (conv_param->kernel_w_ == 5) {
+        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
+                              conv_param->output_w_, input_stride, relu, relu6, 25);
+      }
+    }
+  }
+}
+/*conv depthwise indirect buffer fp32 end*/
+
+/*deconv depthwise fp32 begin*/
+void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
+                         int in_kw_step, int kernel_w_step) {
+  float *dst_kh = dst;
+  const float *weight_kh = weight;
+  for (int kh = 0; kh < height; kh++) {
+    float *dst_kw = dst_kh;
+    const float *weight_kw = weight_kh;
+    for (int kw = 0; kw < width; kw++) {
+#ifdef ENABLE_ARM64
+      float32x4_t src_4 = vld1q_f32(src);
+      float32x4_t weight_4 = vld1q_f32(weight_kw);
+      float32x4_t dst_4 = vld1q_f32(dst_kw);
+      dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
+      vst1q_f32(dst_kw, dst_4);
+#else
+      for (int c = 0; c < C4NUM; c++) {
+        dst_kw[c] += src[c] * weight_kw[c];
+      }
+#endif
+      dst_kw += in_kw_step;
+      weight_kw += C4NUM;
+    }  // kernel_w loop
+    dst_kh += in_kh_step;
+    weight_kh += kernel_w_step;
+  }  // kernel_h loop
+}
+
+void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right,
+                    const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
+  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
+    return;
+  }
+  const float *src_h = src + top * sliding->out_h_step_;
+  for (int ih = top; ih < bottom; ih++) {
+    int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
+    int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
+    int end_kh = MSMIN(conv_param->kernel_h_, 
UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float *dst_h = dst + oh * sliding->in_h_step_; + + const float *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float *dst_w = dst_h + ow * sliding->block_channel_; + + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; +#ifdef ENABLE_ARM64 + DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), + conv_param->kernel_w_ * C4NUM * sizeof(float)); +#else + DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM); +#endif + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, + int in_kw_step) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float *dst_k = dst; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + dst_k[c] += bias[c]; + dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]); + dst_k[c] = (relu6) ? 
(MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
+    }
+    dst_k += block_channel;
+  }
+}
+
+// deconv depthwise fp32: sliding window
+void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
+                    const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
+  const float *src = input_data;
+  float *dst = output_data;
+  if (conv_param->thread_num_ == 0) {
+    return;
+  }
+  for (int b = 0; b < conv_param->output_batch_; b++) {
+    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
+      const float *src_data = src + oc * C4NUM;
+      float *dst_data = dst + oc * C4NUM;
+      const float *weight = weight_data + oc * sliding->kernel_step_;
+      const float *bias = bias_data + oc * C4NUM;
+      DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding);
+      DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_,
+                     conv_param, sliding);
+      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
+                     sliding);
+      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, conv_param->input_w_,
+                     conv_param, sliding);
+
+      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
+        int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
+        int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
+        float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
+        const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
+
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+        DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
+                           conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
+                           sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
+                           sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
+                           sliding->in_kw_step_ * sizeof(float));
+#else
+        DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
+                       conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
+                       sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_);
+#endif
+      }
+      DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param);
+    }  // output C4 loop
+    src += sliding->out_step_;
+    dst += sliding->in_step_;
+  }  // batch loop
+  // output nhwc4
+}
+/*deconv depthwise fp32 end*/
+
+#ifdef ENABLE_AVX
+void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left,
+                            int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,
+                            const DepthwiseSWKernel kernel, int act_type, int ow_block, int oc_block) {
+  // compute the depthwise border
+  int ih = top * conv_param->stride_h_ - conv_param->pad_u_;
+  int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
+  int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
+  const float *src_h = src + ih * sw_param->in_h_step_;
+  float *dst_kernel = dst + left * sw_param->block_channel_;
+  for (int ow = left; ow < right; ow += ow_block) {
+    int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
+    int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
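+    // clip the kernel window so only taps that land inside the input are accumulated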
+    int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
+    const float *src_w = src_h + iw * sw_param->block_channel_;
+    const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_;
+    const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block;
+    kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_block,
+           oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_,
+           (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block);
+    dst_kernel += ow_block * sw_param->block_channel_;
+  }  // width loop
+}
+
+void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
+                        const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) {
+  int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
+  int oh_start = oh_step * task_id;
+  int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_);
+  if (oh_start >= oh_end) {
+    return;
+  }
+  // depthwise sw in x86 avx instructions
+  int oc_tile_ = C8NUM;  // oc is aligned to C8NUM in x86_64 avx
+  int act_type = 0;
+  if (conv_param->act_type_ == ActType_Relu6) {
+    act_type += 1;
+  }
+  if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) {
+    act_type += 2;
+  }
+  int kernel_h = conv_param->kernel_h_;
+  int kernel_w = conv_param->kernel_w_;
+  int output_w = conv_param->output_w_;
+  int oc_algin = sw_param->block_channel_;
+  int oc_num = sw_param->c_block_;
+  int in_step = sw_param->in_step_;
+  int out_step = sw_param->out_step_;
+  int in_sw_step = sw_param->in_sw_step_;
+  int in_kw_step = sw_param->in_kw_step_;
+  int in_kh_step = sw_param->in_kh_step_;
+  int in_sh_step = sw_param->in_sh_step_;
+  int out_right = sw_param->right_;
+  int out_left = sw_param->left_;
+  int out_top = sw_param->top_;
+  int out_bottom = sw_param->bottom_;
+  int kernel_step = sw_param->kernel_step_;
+  int out_h_step = sw_param->out_h_step_;
+  int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_;
+  int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_;
+  int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin;
+  const int ow_block_num[4] = {8, 4, 4, 3};
+  const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel},
+                                          {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel},
+                                          {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel},
+                                          {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}};
+  for (int b = 0; b < conv_param->output_batch_; b++) {
+    for (int oh = oh_start; oh < oh_end; ++oh) {
+      float *dst_oh = output_data + oh * out_h_step;
+      const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step;
+      int oc_block = 0;
+      const float *bias = bias_data;
+      for (int oc = 0; oc < oc_num; oc += oc_block) {
+        oc_block = MSMIN(C4NUM, oc_num - oc);  // 4 3 2 1
+        int oc_step = oc * oc_tile_;
+        const float *weight = weight_data + oc * kernel_step;
+        if (bias != NULL) {
+          bias = bias_data + oc_step;
+        }
+        float *dst_w = dst_oh + oc_step;
+        const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0];
+        if (oh < out_top || oh >= out_bottom) {  // oh in up or down border
+          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param,
+                                 kernel_border, act_type, 1, oc_block);
+        } else {  // oh in center
+          // ow in left border
+          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param,
+                                 kernel_border, act_type, 1, oc_block);
+          // ow in center
+          const float *src_w = src_h + oc_step;
+          int ow_block = ow_block_num[oc_block - 1];  // 8 4 4 3
+          for (int ow = out_left; ow < out_right; ow += ow_block) {  // left ~ right
+            if (ow_block > out_right - ow) {  // remaining ow is smaller than the block: process one ow at a time
+              ow_block = 1;
+            }
+            kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
+              dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin,
+              in_kw_step, in_kh_step, in_sw_step, 0);
+            src_w += ow_block * in_sw_step;
+          }
+          // ow in right border
+          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param,
+                                 sw_param, kernel_border, act_type, 1, oc_block);
+        }
+      }
+    }  // output h loop
+    input_data += in_step;
+    output_data += out_step;
+  }  // batch loop
+}
+
+#ifdef ENABLE_DEBUG
+void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
+                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
+                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
+  __m256 dst_data[12];
+  __m256 src_data;
+  const float *src_kh[12];
+  const float *src_kw[12];
+  __m256 weight_data[4];
+  for (int i = 0; i < ow_block; ++i) {
+    if (bias != NULL) {
+      for (int j = 0; j < oc_block; ++j) {
+        dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8);
+      }
+    } else {
+      for (int j = 0; j < oc_block; ++j) {
+        dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f);
+      }
+    }
+    src_kh[i] = src + i * in_sw_step;
+    src_kw[i] = NULL;
+  }
+  const float *weight_kernel = weight;
+  for (int kh = 0; kh < kernel_h; kh++) {
+    for (int i = 0; i < ow_block; ++i) {
+      src_kw[i] = src_kh[i];
+    }
+    for (int kw = 0; kw < kernel_w; kw++) {
+      for (int j = 0; j < oc_block; ++j) {
+        weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
+      }
+      for (int i = 0; i < ow_block; ++i) {  // loop ow
+        for (int j = 0; j < oc_block; ++j) {
+          src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM);
+          dst_data[i * oc_block + j] += src_data * weight_data[j];
+        }
+      }
+      for (int i = 0; i < ow_block; ++i) {
+        src_kw[i] += in_kw_step;  // ic8 * dilation_w
+      }
+      weight_kernel += oc_block * C8NUM;
+    }  // kernel_w loop
+    weight_kernel += kw_remainder;
+    for (int i = 0; i < ow_block; ++i) {
+      src_kh[i] += in_kh_step;
+    }
+  }  // kernel_h loop
+  // add bias and relu
+  for (int i = 0; i < ow_block; ++i) {
+    for (int j = 0; j < oc_block; ++j) {
+      if (0x1 & act_flag) {  // relu6
+        dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f));
+      }
+      if (0x2 & act_flag) {  // relu
+        dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
+      }
+      _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
+    }
+  }
+}
+#endif
+
+void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
+                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
+                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
+  in_kh_step *= sizeof(float);
+  in_sw_step *= sizeof(float);
+  in_kw_step *= sizeof(float);
+  oc_algin *= sizeof(float);
+  kw_remainder *= sizeof(float);
+  asm volatile(
+    "cmpq $0, %2\n"
+    "je 0f\n"
+    "vmovups (%2), %%ymm0\n"
+    "vmovups 0x20(%2), %%ymm1\n"
+    "vmovups 0x40(%2), %%ymm2\n"
+    "vmovups 0x60(%2), %%ymm3\n"
+    "vmovups 
(%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7), %%ymm14\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7), %%ymm14\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + + "vmovups 0x40(%1), %%ymm12\n" + "vmovups 0x40(%%rcx), %%ymm13\n" + "vmovups 0x40(%%rcx, %7), %%ymm14\n" + "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + + "vmovups 0x60(%1), %%ymm12\n" + "vmovups 0x60(%%rcx), %%ymm13\n" + "vmovups 0x60(%%rcx, %7), %%ymm14\n" + "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $128, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 8 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" 
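    // store 3 ow positions x 32 output channels: row i is written to dst + i * oc_algin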
// dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "vmovups %%ymm4, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm8, (%2, %1, 2)\n" + "vmovups %%ymm9, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovups 0x40(%%rcx), %%ymm6\n" + "vmovups 0x60(%%rcx), %%ymm7\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n" + "addq $128, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); +} + +void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq 
$0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovups (%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x20(%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n" + + "vmovups 0x40(%1), %%ymm12\n" + "vmovups 0x40(%%rcx), %%ymm13\n" + "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x40(%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n" + + "addq $96, %1\n" + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + 
"vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + "vmovups %%ymm11, 0x40(%3)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovups 0x40(%%rcx), %%ymm6\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n" + "addq $96, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); +} + +void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
+ "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovups (%%rcx, %9), %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x20(%%rcx, %9), %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n" + + "addq $64, %1\n" + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm13", + "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + 
"vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "addq $64, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); +} + +void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + float *dst_5 = dst + 5 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups (%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups (%0), %%ymm7\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + + asm volatile( + "LoopH:\n" + "movq %3, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "LoopW:\n" + "movq %%rcx, %%rax\n" + "vmovups (%1), %%ymm12\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vmovups (%%rax, %6, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %7, %%rax\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vmovups (%%rax, %6, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %7, %%rax\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + + "addq $32, %1\n" + "addq %4, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg LoopW\n" 
+ + "addq %5, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %2\n" + "jg LoopH\n" + : + : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step), // 5 + "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 8 + : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", + "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je Write\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + + "and $0x1, %%eax\n" + "je Write\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + + "Write:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "addq $32, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm4"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm12", "%ymm14"); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..988391ab55d830f6ab7ff9843b4c8c5f3848a182 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ +#define MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/base/conv_common_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ENABLE_ARM64 +void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6); +#endif + +int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id); + +int ConvDwAVX512(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param_); + +void ConvDwAVX512Fp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step, bool first_calc_flag, const float *bias); + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block); + +void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int in_block, + int weight_block); + +void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +bool CheckConvDwUse3X3(const ConvParameter *conv_param); + +bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param); + +void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, + int step_h, int step_w); + +#ifdef ENABLE_ARM64 +void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, size_t input_stride, size_t relu, size_t relu6); + +void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, size_t input_stride, size_t relu, size_t relu6); +#endif + +#ifdef ENABLE_AVX +typedef void (*DepthwiseSWKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t 
kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, + const DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block); + +void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); + +void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); +#ifdef ENABLE_DEBUG +void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); +#endif +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6); +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); +#endif + +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel); + +void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, + float *zero_ptr, const ConvParameter *conv_param, int task_id); + 
+void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..308bae077faa3df36de48918ffc6b37829041413 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c @@ -0,0 +1,92 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_im2col_avx512_fp32.h" +#include "nnacl_c/fp32/conv_im2col_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_avx512_instructions.h" + +// fp32 conv common +void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param, + int cal_num) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM); + + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } + int out_stride = out_channel_align * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel_align * output_hw + start_hw * out_channel_align; + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + + float *gemm_output = output_data + out_offset; + MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + out_channel_align, out_channel_align, real_cal_row); + } + } +} + +// fp32 conv common +void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, + const ConvParameter *conv_param, int cal_num) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int 
out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM); + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + int out_stride = out_channel_align * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = start_batch; b < end_batch; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel_align * output_hw; + for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + + float *gemm_output = output_data + out_offset; + MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + out_channel_align, out_channel_align, real_cal_row); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..9991c8d7985065554402a791c66d75a98cbf1d28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ +#define MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param, + int cal_num); + +void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, + const ConvParameter *conv_param, int cal_num); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9c5f8c9277a37fd51f14321667731d111c97a1f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c @@ -0,0 +1,65 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_im2col_fp32.h" + +void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index) { + // input format : nhwc + int kernel_w = conv_param->kernel_w_; + int kernel_h = conv_param->kernel_h_; + int kernel_plane = kernel_h * kernel_w; + int dilation_w = conv_param->dilation_w_; + int dilation_h = conv_param->dilation_h_; + + int out_w = conv_param->output_w_; + if (dilation_w == 0 || dilation_h == 0 || out_w == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) { + continue; + } + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int input_stride = (input_h * in_w + input_w) * in_channel; + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..19ff0418d47ad078e72d80296abdc7a8d22be968 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ +#define MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h new file mode 100644 index 0000000000000000000000000000000000000000..e95a691122c885574730089b03f935bea243f02a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h @@ -0,0 +1,131 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_SW_H_ +#define MINDSPORE_NNACL_FP32_CONV_SW_H_ + +#define GenerateConvSWFunc(backend, oc_unit_num, row_num_list, kernel_list, compute_core, outer_compute) \ + void SWBorder##backend(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, \ + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, \ + const SWConvKernel kernel, int act_type, int ow_bock, int oc_block, size_t write_mode) { \ + for (int oh = top; oh < bottom; oh++) { \ + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; \ + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); \ + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); \ + const float *src_h = src + ih * sw_param->in_h_step_; \ + float *dst_kernel = dst + left * sw_param->out_w_step_; \ + for (int ow = left; ow < right; ow += ow_bock) { \ + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; \ + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); \ + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); \ + const float *src_w = src_h + iw * sw_param->ic_align_; \ + const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_; \ + const float *weight_kernel = \ + weight + (start_kh * conv_param->kernel_w_ + start_kw) * sw_param->ic_align_ * C8NUM * oc_block; \ + outer_compute dst_kernel += ow_bock * sw_param->out_w_step_; \ + } \ + dst += sw_param->out_h_step_; \ + } \ + } \ + \ + void ConvSW##backend##Fp32(const float *input_data, const float *packed_weight, const float *bias_data, \ + float *output_data, int task_id, ConvParameter *conv_param, \ + SlidingWindowParam *sw_param) { \ + int out_h = conv_param->output_h_; \ + int oh_step = UP_DIV(out_h, conv_param->thread_num_); \ + int oh_start = oh_step * task_id; \ + int oh_end = MSMIN(oh_start + oh_step, out_h); \ + if (oh_start >= oh_end) { \ + return; \ + } \ + int oc_tile_ = 
C8NUM; /* oc is aligned to C8NUM on arm64 */ \ + int act_type = 0; \ + if (conv_param->act_type_ == ActType_Relu6) { \ + act_type += 1; \ + } \ + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { \ + act_type += 2; \ + } \ + int kernel_h = conv_param->kernel_h_; \ + int kernel_w = conv_param->kernel_w_; \ + int ic_algin = sw_param->ic_align_; \ + int in_sw_step = sw_param->in_sw_step_; \ + int in_kw_step = sw_param->in_kw_step_; \ + int in_kh_step = sw_param->in_kh_step_; \ + int in_sh_step = sw_param->in_sh_step_; \ + int out_h_step = sw_param->out_h_step_; \ + int out_c_step = sw_param->out_c_step_; \ + int out_w_step = sw_param->out_w_step_; \ + int out_block_step = sw_param->out_block_step_; \ + int kernel_step = sw_param->kernel_step_; \ + int in_step = sw_param->in_step_; \ + int out_step = sw_param->out_step_; \ + int c_block = sw_param->c_block_; \ + int top = sw_param->top_; \ + int left = sw_param->left_; \ + int right = sw_param->right_; \ + int bottom = sw_param->bottom_; \ + int stride_h = conv_param->stride_h_; \ + int stride_w = conv_param->stride_w_; \ + int out_w = conv_param->output_w_; \ + int pad_u = conv_param->pad_u_; \ + int pad_l = conv_param->pad_l_; \ + int in_h_step = sw_param->in_h_step_; \ + int out_batch = conv_param->output_batch_; \ + int in_h_start = top * stride_h - pad_u; \ + int in_w_start = left * stride_w - pad_l; \ + int center_step = in_h_start * in_h_step + in_w_start * ic_algin; \ + int write_mode = conv_param->out_format_; \ + row_num_list kernel_list for (int b = 0; b < out_batch; b++) { \ + for (int oh = oh_start; oh < oh_end; oh += 1) { \ + float *dst_oh = output_data + oh * out_h_step; \ + const float *src_h = input_data + center_step; \ + \ + int oc_block = 0; \ + const float *bias = bias_data; \ + for (int oc = 0; oc < c_block; oc += oc_block) { \ + oc_block = MSMIN(oc_unit_num, c_block - oc); \ + const float *weight = packed_weight + oc * kernel_step; \ + if (bias != NULL) { \ + bias = bias_data + oc * oc_tile_; \ + } \ + float *dst_oc = dst_oh + oc * out_c_step; \ + const SWConvKernel kernel_border = kernel[oc_block - 1][0]; \ + if (oh < top || oh >= bottom) { /* oh in top or bottom border */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + } else { /* oh in center */ \ + /* ow in left border */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + /* ow in center */ \ + const float *src_w = src_h + (oh - top) * in_sh_step; \ + int ow_block = ow_block_num[oc_block - 1]; \ + for (int ow = left; ow < right; ow += ow_block) { /* left ~ right */ \ + ow_block = MSMIN(ow_block, right - ow); \ + compute_core src_w += ow_block * in_sw_step; \ + } \ + /* ow in right border */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + } \ + } \ + } /* output h loop */ \ + input_data += in_step; \ + output_data += out_step; \ + } /* batch loop */ \ + } +#endif // MINDSPORE_NNACL_FP32_CONV_SW_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5b15d4fbf20c87940fb97df17d9a3a0045b56c2d --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c @@ -0,0 +1,99 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" +#include "nnacl_c/fp32/conv_sw.h" + +bool CheckArm64UseSWConv(const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return false; + } + if (conv_param->input_channel_ > C128NUM) { + return false; + } + if (conv_param->kernel_h_ > C5NUM || conv_param->kernel_w_ > C5NUM) { + return false; + } + if (conv_param->dilation_h_ != 1 || conv_param->dilation_w_ != 1) { + return false; + } + if (conv_param->stride_w_ > C3NUM) { + return false; + } + if (conv_param->input_h_ / conv_param->kernel_h_ < C48NUM || conv_param->input_w_ / conv_param->kernel_w_ < C48NUM) { + return false; + } + return true; +} + +typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv2x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv2x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv3x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv3x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, 
size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv5x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv5x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +#define ROW_NUM_LIST const int ow_block_num[2] = {5, 5}; +#define KERNEL_LIST \ + const SWConvKernel kernel[2][5] = { \ + {SWConv1x8Kernel, SWConv2x8Kernel, SWConv3x8Kernel, SWConv4x8Kernel, SWConv5x8Kernel}, \ + {SWConv1x16Kernel, SWConv2x16Kernel, SWConv3x16Kernel, SWConv4x16Kernel, SWConv5x16Kernel}}; +#define COMPUTE_CORE \ + kernel[oc_block - 1][ow_block - 1](dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, \ + out_block_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); +#define OUTER_COMPUTE \ + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, \ + sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, \ + sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \ + write_mode); +GenerateConvSWFunc(Arm64, C2NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..cb38240d35f046ffd1594559fc415176fcaab123 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP32_CONV_SW_ARM64_FP32_H_ +#define NNACL_FP32_CONV_SW_ARM64_FP32_H_ +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +bool CheckArm64UseSWConv(const ConvParameter *conv_param); +void ConvSWArm64Fp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_CONV_SW_ARM64_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1c979e63578582ba4d839e2d990c21ece5f4a8fe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c @@ -0,0 +1,1231 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_sw_avx_fp32.h" +#include "nnacl_c/fp32/conv_sw.h" +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" + +void SWConv3x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %8), %%ymm14\n" + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vmovups (%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, 
%%ymm8\n" + "vmovups 0x20(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 0x40(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 0x60(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $128, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "vmovups %%ymm4, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm8, (%2, %1, 2)\n" + "vmovups %%ymm9, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm4, 0x20(%2)\n" + "vmovups %%ymm8, 0x40(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm9, 0x40(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm6, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "vmovups %%ymm7, 0x20(%4)\n" + "vmovups %%ymm11, 0x40(%4)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t 
oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm4\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + "addq $128, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); +} + +void SWConv4x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups 0x20(%0), %%ymm1\n" + "vmovups 0x40(%0), %%ymm2\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
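+ // The same three bias vectors seed all four row accumulator groups (ymm0-ymm11), so the block below
+ // simply re-reads (%0), 0x20(%0) and 0x40(%0) once per output row.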
+ "vmovups (%0), %%ymm3\n" + "vmovups 0x20(%0), %%ymm4\n" + "vmovups 0x40(%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups 0x20(%0), %%ymm7\n" + "vmovups 0x40(%0), %%ymm8\n" + "vmovups (%0), %%ymm9\n" + "vmovups 0x20(%0), %%ymm10\n" + "vmovups 0x40(%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", + "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + "vmovups 0x40(%1), %%ymm14\n" + + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm2\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm5\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "addq %2, %%rdx\n" // src_3 + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + + "subq %2, %%rdx\n" + "addq $96, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(src_3_step), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, 
%%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + "vmovups %%ymm11, 0x40(%3)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm3, 0x20(%2)\n" + "vmovups %%ymm6, 0x40(%2)\n" + "vmovups %%ymm9, 0x60(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm7, 0x40(%2, %1, 1)\n" + "vmovups %%ymm10, 0x60(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm3\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm3, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm3, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm3, %%ymm2\n" + "addq $96, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc4hw4 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); +} + +void SWConv6x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups 0x20(%0), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
+ "vmovups (%0), %%ymm2\n" + "vmovups 0x20(%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups 0x20(%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups 0x20(%0), %%ymm7\n" + "vmovups (%0), %%ymm8\n" + "vmovups 0x20(%0), %%ymm9\n" + "vmovups (%0), %%ymm10\n" + "vmovups 0x20(%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", + "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm3\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm5\n" + + "addq %2, %%rdx\n" // src_3 + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm9\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm11\n" + + "subq %2, %%rdx\n" + "addq $64, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(src_3_step), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps 
%%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm4, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm6, (%3)\n" // dst+3 + "vmovups %%ymm7, 0x20(%3)\n" + "vmovups %%ymm8, (%3, %1, 1)\n" + "vmovups %%ymm9, 0x20(%3, %1, 1)\n" + "vmovups %%ymm10, (%3, %1, 2)\n" + "vmovups %%ymm11, 0x20(%3, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm2, 0x20(%2)\n" + "vmovups %%ymm4, 0x40(%2)\n" + "vmovups %%ymm6, 0x60(%2)\n" // dst+3 + "vmovups %%ymm8, 0x80(%2)\n" + "vmovups %%ymm10, 0xA0(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm9, 0x80(%2, %1, 1)\n" + "vmovups %%ymm11, 0xA0(%2, %1, 1)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm3\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm3, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm3, %%ymm1\n" + "addq $64, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm3"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "jmp 2f\n" + "1:\n" + // write nc8hw8 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); +} + +void SWConv12x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + float *dst_5 = dst + 5 * out_step; + float *dst_9 = dst + 9 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups (%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups (%0), %%ymm7\n" + "vmovups (%0), %%ymm8\n" + "vmovups (%0), %%ymm9\n" + "vmovups (%0), %%ymm10\n" + "vmovups (%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + + asm volatile( + "LoopH:\n" + "movq %3, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "LoopW:\n" + "movq %%rcx, %%rdx\n" + "movq %4, %%r12\n" // ic_algin + "LoopIC:\n" + "vmovups (%1), %%ymm12\n" + "addq $32, %1\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + 
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "subq %8, %%rdx\n" + "subq %8, %%rdx\n" + "subq %8, %%rdx\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg LoopIC\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg LoopW\n" + + "addq %6, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %2\n" + "jg LoopH\n" + : + : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), "r"(in_kh_step), // 6 + "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%r12", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je Write\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je Write\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "Write:\n" + "cmpq $13, %6\n" + "je WriteNC8HW8\n" + // write nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovups %%ymm8, (%2, %1, 8)\n" + "vmovups %%ymm9, (%5)\n" // dst_9 + "vmovups %%ymm10, (%5, %1, 1)\n" + "vmovups %%ymm11, (%5, %1, 2)\n" + "jmp End\n" + "WriteNC8HW8:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" // dst_3 + "vmovups %%ymm4, 0x80(%2)\n" + "vmovups %%ymm5, 0xA0(%2)\n" // dst_5 + "vmovups %%ymm6, 0xC0(%2)\n" + "vmovups %%ymm7, 0xE0(%2)\n" + "vmovups %%ymm8, 0x100(%2)\n" + "vmovups %%ymm9, 0x120(%2)\n" // dst_9 + "vmovups %%ymm10, 0x140(%2)\n" + "vmovups %%ymm11, 0x160(%2)\n" + "End:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", 
"%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv4x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + size_t src_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "movq %%rdx, %%rax\n" + "addq $32, %1\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %8), %%ymm14\n" + "vbroadcastss (%%rax, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %9, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %2, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(kw_remainder), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(src_step) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm14"); +} + +void SWConv1x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" // LoopH + "movq %4, 
%%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm1\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm1, %%ymm0\n" + "addq $32, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "0:\n" + // write to nhec and nc8hw8 is identical! + "vmovups %%ymm0, (%2)\n" // dst_0 + : + : "a"(act_flag), "r"(out_step), "r"(dst) + : "%ecx", "%ymm0", "%ymm12", "%ymm14"); +} + +typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, + size_t kw_remainder, size_t write_mode); + +#define ROW_NUM_LIST const int ow_block_num[4] = {12, 6, 4, 3}; +#define KERNEL_LIST \ + const SWConvKernel kernel[4][2] = {{SWConv1x8AVXKernel, SWConv12x8AVXKernel}, \ + {SWConv1x16AVXKernel, SWConv6x16AVXKernel}, \ + {SWConv1x24AVXKernel, SWConv4x24AVXKernel}, \ + {SWConv1x32AVXKernel, SWConv3x32AVXKernel}}; +#define COMPUTE_CORE \ + if (ow_block < ow_block_num[oc_block - 1]) { \ + ow_block = 1; \ + } \ + kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]]( \ + dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, out_block_step, \ + ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); +#define OUTER_COMPUTE \ + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock, \ + oc_block, sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, \ + sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \ + write_mode); + +GenerateConvSWFunc(AVX, C4NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE); + +#ifdef ENABLE_DEBUG +void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + __m256 dst_data[12]; + const float *src_kh[12]; + const float *src_kw[12]; + __m256 weight_data[4]; + for (int i = 0; i < ow_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + src_kh[i] = src + i * in_sw_step; + src_kw[i] = NULL; + } + const float *weight_kernel = weight; + for (int kh = 
+#ifdef ENABLE_DEBUG
+void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
+                        size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
+                        size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
+                        size_t write_mode) {
+  __m256 dst_data[12];
+  const float *src_kh[12];
+  const float *src_kw[12];
+  __m256 weight_data[4];
+  for (int i = 0; i < ow_block; ++i) {
+    if (bias != NULL) {
+      for (int j = 0; j < oc_block; ++j) {
+        dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM);
+      }
+    } else {
+      for (int j = 0; j < oc_block; ++j) {
+        dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f);
+      }
+    }
+    src_kh[i] = src + i * in_sw_step;
+    src_kw[i] = NULL;
+  }
+  const float *weight_kernel = weight;
+  for (int kh = 0; kh < kernel_h; kh++) {
+    for (int i = 0; i < ow_block; ++i) {
+      src_kw[i] = src_kh[i];
+    }
+    for (int kw = 0; kw < kernel_w; kw++) {
+      for (int ic = 0; ic < ic_algin; ++ic) {
+        for (int j = 0; j < oc_block; ++j) {
+          weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
+        }
+        for (int i = 0; i < ow_block; ++i) {
+          for (int j = 0; j < oc_block; ++j) {
+            dst_data[i * oc_block + j] += src_kw[i][ic] * weight_data[j];
+          }
+        }
+        weight_kernel += C8NUM * oc_block;
+      }  // ic loop
+      for (int i = 0; i < ow_block; ++i) {
+        src_kw[i] += in_kw_step;
+      }
+    }  // kernel_w loop
+    weight_kernel += kw_remainder;
+    for (int i = 0; i < ow_block; ++i) {
+      src_kh[i] += in_kh_step;
+    }
+  }  // kernel_h loop
+  // apply the activation and write the output block (bias was already folded in at init)
+  for (int i = 0; i < ow_block; ++i) {
+    for (int j = 0; j < oc_block; ++j) {
+      if (0x1 & act_flag) {  // relu6
+        dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f));
+      }
+      if (0x2 & act_flag) {  // relu
+        dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
+      }
+      if (write_mode == C13NUM) {
+        // write nc8hw8
+        _mm256_storeu_ps(dst + j * out_step + i * C8NUM, dst_data[i * oc_block + j]);
+      } else {
+        // write nhwc
+        _mm256_storeu_ps(dst + i * out_step + j * C8NUM, dst_data[i * oc_block + j]);
+      }
+    }
+  }
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..ec4bd2b358f32df02bc5c49a39c361cb1fd53728
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ +#define MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ + +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvSWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); + +#ifdef ENABLE_DEBUG +void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d4e1672132378f32c9365bc0090a6e31f52a8b07 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/fp32/matmul_fp32.h" + +// fp32 conv winograd +void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, + TransFuncList trans_func) { + if (conv_param->output_unit_ == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + const int tile_num = C12NUM; + int output_tile_count = UP_DIV(output_count, tile_num); +#ifdef ENABLE_AVX + const int col_tile = C16NUM; + const int channel_pack_tile = C8NUM; +#else + const int col_tile = C8NUM; + const int channel_pack_tile = C4NUM; +#endif + int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + + int block_per_thread = UP_DIV(output_tile_count, conv_param->thread_num_); + int start_index = block_per_thread * task_id * tile_num; + if (start_index >= output_count) { + return; + } + int end_index = MSMIN(start_index + block_per_thread * tile_num, output_count); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + + for (int out_tile_index = start_index; out_tile_index < end_index; out_tile_index += tile_num) { + int cal_num = output_count - out_tile_index; + cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? 
channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#else + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#endif + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + float *output_ptr = output_data + out_batch_offset; + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#else + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#endif + } + } + } +} + +// fp32 conv winograd +void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data, + float *output_data, TmpBufferAddress *buffer_list, int task_id, + const ConvParameter *conv_param, TransFuncList trans_func) { + if (conv_param->output_unit_ == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + const int tile_num = C12NUM; +#ifdef ENABLE_AVX + const int col_tile = C16NUM; + const int channel_pack_tile = C8NUM; +#else + const int col_tile = C8NUM; + const int channel_pack_tile = C4NUM; +#endif + int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id 
* tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + for (int b = start_batch; b < end_batch; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + + for (int out_tile_index = 0; out_tile_index < output_count; out_tile_index += tile_num) { + int cal_num = output_count - out_tile_index; + cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? 
channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#else + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#endif + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + float *output_ptr = output_data + out_batch_offset; + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#else + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#endif + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6d3b49823608991da9e549dd90ae90aad19538c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_
+#define MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// fp32 convolution winograd
+void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
+                      TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param,
+                      TransFuncList trans_func);
+
+void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data,
+                                float *output_data, TmpBufferAddress *buffer_list, int task_id,
+                                const ConvParameter *conv_param, TransFuncList trans_func);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..c171deca2792923427d1499ca017c8b7de31bdb8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c
@@ -0,0 +1,94 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/crop_fp32.h"
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/crop_parameter.h"
+
+void Pad4DOffset(const CropParameter *crop_param, int64_t *offset, int length) {
+  int axis = crop_param->axis_;
+  for (int i = length - 1; i >= 0; --i) {
+    int offset_index = i - axis;
+    if (offset_index >= 0 && offset_index < COMM_SHAPE_SIZE) {
+      offset[i] = crop_param->offset_[offset_index];
+    } else {
+      offset[i] = 0;
+    }
+  }
+}
+
+void Crop4D(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape,
+            const CropParameter *crop_param, int thread_id, int thread_num) {
+  int64_t offset_pad[DIMENSION_4D] = {0};
+  Pad4DOffset(crop_param, offset_pad, DIMENSION_4D);
+  int out_shape1 = out_shape[1];
+  int out_shape2 = out_shape[2];
+  int out_shape3 = out_shape[3];
+  size_t out_stride2 = out_shape3;
+  size_t out_stride1 = out_stride2 * out_shape2;
+  size_t out_stride0 = out_stride1 * out_shape1;
+  size_t in_stride2 = in_shape[3];
+  size_t in_stride1 = in_stride2 * in_shape[2];
+  size_t in_stride0 = in_stride1 * in_shape[1];
+  size_t copy_size = out_shape3 * sizeof(float);
+
+  size_t count_per_thread = UP_DIV(out_shape1, thread_num);
+  size_t thread_stride = thread_id * count_per_thread;
+  for (int i = 0; i < out_shape[0]; ++i) {
+    size_t out_offset0 = i * out_stride0;
+    size_t in_offset0 = (i + offset_pad[0]) * in_stride0 + offset_pad[3];
+    for (size_t j = 0; j < count_per_thread; ++j) {
+      size_t k = j + thread_stride;
+      if (k >= out_shape1) {
+        break;
+      }
+      size_t out_offset1 = k * out_stride1 + out_offset0;
+      size_t in_offset1 = (k + offset_pad[1]) * in_stride1 + in_offset0;
+      for (int l = 0; l < out_shape2; ++l) {
+        size_t out_offset = l * out_stride2 + out_offset1;
+        size_t in_offset = (l + offset_pad[2]) * in_stride2 + in_offset1;
+        memcpy(output + out_offset, input + in_offset, copy_size);
+      }
+    }
+  }
+}
+
+void Crop4DNoParallel(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape,
+                      const CropParameter *crop_param) {
+  int64_t offset_pad[DIMENSION_4D] = {0};
+  Pad4DOffset(crop_param, offset_pad, DIMENSION_4D);
+  size_t in_dim2_stride = in_shape[3];
+  size_t in_dim1_stride = in_shape[2] * in_dim2_stride;
+  size_t in_dim0_stride = in_dim1_stride * in_shape[1];
+  size_t offset_3 = offset_pad[3];
+  size_t out_offset = 0;
+  size_t copy_num = out_shape[3];
+  size_t copy_size = copy_num * sizeof(float);
+  size_t in_dim0_end = offset_pad[0] + out_shape[0];
+  size_t in_dim1_end = offset_pad[1] + out_shape[1];
+  size_t in_dim2_end = offset_pad[2] + out_shape[2];
+  for (int i = offset_pad[0]; i < in_dim0_end; ++i) {
+    size_t dim0_offset = (size_t)i * in_dim0_stride + offset_3;
+    for (int j = offset_pad[1]; j < in_dim1_end; ++j) {
+      size_t dim1_offset = (size_t)j * in_dim1_stride + dim0_offset;
+      for (int k = offset_pad[2]; k < in_dim2_end; ++k) {
+        size_t in_offset = dim1_offset + (size_t)k * in_dim2_stride;
+        memcpy(output + out_offset, input + in_offset, copy_size);
+        out_offset += copy_num;
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..07b66e6268fb8b6f163da617c5604444f9c05394
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CROP_FP32_H_ +#define NNACL_FP32_CROP_FP32_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Crop4D(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param, int thread_id, int thread_num); +void Crop4DNoParallel(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_CROP_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1900c27bc4a51391cc14271429eb00c5d3a2a57b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/cumsum_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/cumsum_fp32_simd.h" + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +void Cumsum(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + // when not exclusive, output axis dim[0] is the same as that of input. + // when exclusive, output axis dim[0] is 0.0f + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim; + float *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + float *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0.0f; + } + } + } + int input_offset = exclusive ? 
0 : 1; + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim + inner_dim * input_offset; + float *layer_last_output = output + i * axis_dim * inner_dim; + float *layer_output = layer_last_output + inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(Cumsum, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + // layer_output (i, j, k) = layer_input (i, j, k) + layer_last_output (i,j-1, k) + *(layer_output + k) = *(layer_input + k) + *(layer_last_output + k); + } + layer_input += inner_dim; + layer_last_output += inner_dim; + layer_output += inner_dim; + } + } +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +void CumsumReverse(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + float *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + float *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0.0f; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + (i + 1) * axis_dim * inner_dim - 1 - input_offset * inner_dim; + float *layer_last_output = output + (i + 1) * axis_dim * inner_dim - 1; + float *layer_output = layer_last_output - inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumReverse, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output - k) = *(layer_input - k) + *(layer_last_output - k); + } + layer_input -= inner_dim; + layer_last_output -= inner_dim; + layer_output -= inner_dim; + } + } +} + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +void CumsumInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + // when not exclusive, output axis dim[0] is the same as that of input. + // when exclusive, output axis dim[0] is 0 + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim; + int32_t *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + int32_t *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output++) = 0; + } + } + } + int input_offset = exclusive ? 
0 : 1;
+  for (int i = 0; i < out_dim; ++i) {
+    const int32_t *layer_input = input + i * axis_dim * inner_dim + inner_dim * input_offset;
+    int32_t *layer_last_output = output + i * axis_dim * inner_dim;
+    int32_t *layer_output = layer_last_output + inner_dim;
+
+    for (int j = 1; j < axis_dim; ++j) {
+      int k = 0;
+      SIMD_RUN_NO_SCALAR(CumsumInt, k, layer_input, layer_output, layer_last_output, inner_dim);
+      for (; k < inner_dim; ++k) {
+        *(layer_output + k) = *(layer_input + k) + *(layer_last_output + k);
+      }
+      layer_input += inner_dim;
+      layer_last_output += inner_dim;
+      layer_output += inner_dim;
+    }
+  }
+}
+
+// (a, b, c) -> (c+b+a, c+b, c) exclusive==false
+// (a, b, c) -> (c+b, c, 0) exclusive==true
+void CumsumReverseInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) {
+  if (!exclusive) {
+    for (int i = 0; i < out_dim; ++i) {
+      const int32_t *layer_input = input + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim;
+      int32_t *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim;
+
+      int j = 0;
+      SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithInput, j, layer_input, layer_output, inner_dim);
+      for (; j < inner_dim; ++j) {
+        *(layer_output++) = *(layer_input++);
+      }
+    }
+  } else {
+    for (int i = 0; i < out_dim; ++i) {
+      int32_t *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim;
+
+      int j = 0;
+      SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithZero, j, layer_output, inner_dim);
+      for (; j < inner_dim; ++j) {
+        *(layer_output++) = 0;  // integer output: plain 0, not 0.0f
+      }
+    }
+  }
+  int input_offset = exclusive ? 0 : 1;
+  for (int i = 0; i < out_dim; ++i) {
+    const int32_t *layer_input = input + (i + 1) * axis_dim * inner_dim - 1 - input_offset * inner_dim;
+    int32_t *layer_last_output = output + (i + 1) * axis_dim * inner_dim - 1;
+    int32_t *layer_output = layer_last_output - inner_dim;
+
+    for (int j = 1; j < axis_dim; ++j) {
+      int k = 0;
+      SIMD_RUN_NO_SCALAR(CumsumReverseInt, k, layer_input, layer_output, layer_last_output, inner_dim);
+      for (; k < inner_dim; ++k) {
+        *(layer_output - k) = *(layer_input - k) + *(layer_last_output - k);
+      }
+      layer_input -= inner_dim;
+      layer_last_output -= inner_dim;
+      layer_output -= inner_dim;
+    }
+  }
+}
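The inclusive/exclusive contract in the comments above is easiest to verify on a tiny input. A smoke test of the float variant (illustrative only; `CumsumSmokeTest` is not part of the patch), taking a single outer slice (out_dim = 1) and unit inner stride (inner_dim = 1) so the scan runs purely along axis_dim:

#include <assert.h>
#include <stdbool.h>

/* Checks (a, b, c) -> (a, a+b, a+b+c) and (a, b, c) -> (0, a, a+b)
 * against the Cumsum defined above. */
static void CumsumSmokeTest(void) {
  const float in[3] = {1.0f, 2.0f, 3.0f};
  float out[3] = {0.0f, 0.0f, 0.0f};
  Cumsum(in, out, 1, 3, 1, false); /* inclusive: (1, 2, 3) -> (1, 3, 6) */
  assert(out[0] == 1.0f && out[1] == 3.0f && out[2] == 6.0f);
  Cumsum(in, out, 1, 3, 1, true); /* exclusive: (1, 2, 3) -> (0, 1, 3) */
  assert(out[0] == 0.0f && out[1] == 1.0f && out[2] == 3.0f);
}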
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..6b30cfc97d6ddcdeb5dabd4144adc6207699fb56
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_CUMSUM_H_
+#define MINDSPORE_NNACL_FP32_CUMSUM_H_
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/cumsum_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Cumsum(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive);
+void CumsumReverse(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive);
+void CumsumInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive);
+void CumsumReverseInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_CUMSUM_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..ad5aa2877bf7e1d86c065dcdc08958d2136ba43e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_CUMSUM_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_CUMSUM_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+// (a, b, c) -> (a, a+b, a+b+c) exclusive == false
+// (a, b, c) -> (0, a, a+b) exclusive == true
+static inline int64_t CumsumOutputInitWithInput@SIMD_INSTRUCTION@(int64_t index, const float *layer_input,
+                                                                  float *layer_output, int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(layer_output + index, SIMD_LD_F32(layer_input + index));
+  }
+  return index;
+}
+
+static inline int64_t CumsumOutputInitWithZero@SIMD_INSTRUCTION@(int64_t index, float *layer_output, int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(layer_output + index, SIMD_MOV_F32(0.0f));
+  }
+  return index;
+}
+
+static inline int64_t Cumsum@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, float *layer_output, float *layer_last_output,
+                                               int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 input_val = SIMD_LD_F32(layer_input + index);
+    SIMD_F32 last_output_val = SIMD_LD_F32(layer_last_output + index);
+    SIMD_F32 out_val = SIMD_ADD_F32(input_val, last_output_val);
+    SIMD_ST_F32(layer_output + index, out_val);
+  }
+  return index;
+}
+
+// (a, b, c) -> (c+b+a, c+b, c) exclusive==false
+// (a, b, c) -> (c+b, c, 0) exclusive==true
+static inline int64_t CumsumReverse@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, float *layer_output,
+                                                      float *layer_last_output, int inner_dim) {
+
+// (a, b, c) -> (a, a+b, a+b+c) exclusive == false
+// (a, b, c) -> (0, a, a+b) exclusive == true
+static inline int64_t CumsumIntOutputInitWithInput@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input,
+                                                                     int32_t *layer_output, int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(layer_output + index, SIMD_LD_EPI32(layer_input + index));
+  }
+  return index;
+}
+
+static inline int64_t CumsumIntOutputInitWithZero@SIMD_INSTRUCTION@(int64_t index, int32_t *layer_output,
+                                                                    int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(layer_output + index, SIMD_MOV_EPI32(0));  // integer splat, not 0.0f
+  }
+  return index;
+}
+
+static inline int64_t CumsumInt@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input, int32_t *layer_output,
+                                                  int32_t *layer_last_output, int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 input_val = SIMD_LD_EPI32(layer_input + index);
+    SIMD_EPI32 last_output_val = SIMD_LD_EPI32(layer_last_output + index);
+    SIMD_EPI32 out_val = SIMD_ADD_EPI32(input_val, last_output_val);
+    SIMD_ST_EPI32(layer_output + index, out_val);
+  }
+  return index;
+}
+
+// (a, b, c) -> (c+b+a, c+b, c) exclusive==false
+// (a, b, c) -> (c+b, c, 0) exclusive==true
+static inline int64_t CumsumReverseInt@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input,
+                                                         int32_t *layer_output, int32_t *layer_last_output,
+                                                         int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 input_val = SIMD_LD_EPI32(layer_input - index - BLOCK_NUM + 1);
+    SIMD_EPI32 last_output_val = SIMD_LD_EPI32(layer_last_output - index - BLOCK_NUM + 1);
+    SIMD_EPI32 out_val = SIMD_ADD_EPI32(input_val, last_output_val);
+    SIMD_ST_EPI32(layer_output - index - BLOCK_NUM + 1, out_val);
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..290ba8fdf394a406c5ae34fe41c75d6a2067edc5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c
@@ -0,0 +1,72 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/custom_gru_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+
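+/* Summary of the recurrence implemented below (gate order inferred from the buffer
+ * offsets used in this patch): with r, z, n stored output_size floats apart,
+ *   r = sigmoid(Wr*x + Ur*h_prev)
+ *   z = sigmoid(Wz*x + Uz*h_prev)
+ *   n = tanh(Wn*x + r * (Un*h_prev))
+ *   h = (1 - z) * n + z * h_prev
+ * input_gate holds the three W*x slices and hidden_gate the three U*h slices. */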
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+               const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+               const CustomGruParameter *gru_param) {
+  int num_step = gru_param->num_step;
+  int batch_size = gru_param->batch_size;
+  int input_size = gru_param->input_size;
+  int hidden_size = gru_param->hidden_size;
+  int output_size = batch_size * hidden_size;
+  int double_output_size = output_size * C2NUM;
+  int col_align = UP_ROUND(hidden_size, C8NUM);
+  int weight_in_offset = col_align * input_size;
+  int weight_hidden_offset = col_align * hidden_size;
+  float *input_gate = buffer[1];
+  float *hidden_gate = buffer[C3NUM];
+  for (int i = 0; i < num_step; ++i) {
+    if (batch_size != 1) {
+      // batch > 1: pack the rows once, then run one tiled matmul per gate (r, z, n)
+      RowMajor2Col12MajorParallel(input + i * batch_size * input_size, buffer[0], batch_size, input_size, 0,
+                                  batch_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+                  bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+                  OutType_Nhwc);
+      }
+      RowMajor2Col12MajorParallel(init_h, buffer[C2NUM], batch_size, hidden_size, 0, batch_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                  bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+                  OutType_Nhwc);
+      }
+    } else {
+      // batch == 1: packed matrix-vector products per gate
+      for (int j = 0; j < C3NUM; ++j) {
+        MatVecMulPackFp32(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+                          bias_input + j * col_align, ActType_No, input_size, hidden_size);
+        MatVecMulPackFp32(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                          bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size);
+      }
+    }
+    // gates: r and z are the first two slices of sigmoid(W*x + U*h_prev)
+    ElementAdd(input_gate, hidden_gate, input_gate, double_output_size);
+    Sigmoid(input_gate, double_output_size, input_gate);
+    // candidate: n = tanh(Wn*x + r * (Un*h_prev)), computed in slice 0
+    ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+    ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size);
+    Tanh(input_gate, output_size, input_gate);
+    // new state: h = n + z * (h_prev - n) == (1 - z) * n + z * h_prev
+    ElementSub(init_h, input_gate, hidden_gate, output_size);
+    ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+    ElementAdd(input_gate, hidden_gate, output, output_size);
+    init_h = output;
+    output += output_size;
+  }
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..16c4774998f4b37f8d49a847c70f299414aa6386
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#ifdef ENABLE_ARM64 +#include "nnacl_c/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..72bec29a2d17cceb539aab82d4e6d95a6794e172 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/deconv_fp32.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane) { + /* ichwoc(nhwc) -> oc4 * h * w * incUP4 * 4 */ + int ic_up4 = UP_ROUND(input_channel, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM; + int oc4mod = oc % C4NUM; + for (int ic = 0; ic < input_channel; ic++) { + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * plane * output_channel + hw * output_channel + oc; + int dst_index = oc4div * ic_up4 * plane * C4NUM + hw * ic_up4 * C4NUM + ic * C4NUM + oc4mod; + dst[dst_index] = weight[src_index]; + } + } + } + return; +} + +void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *dst, int output_channel, + const ConvParameter *conv_param) { + /* arm64 row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + /* arm32 row4x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_ROUND(output_channel, C8NUM); +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + const int tile_num = 4; +#else + const int tile_num = 12; +#endif + int in_plane_round = UP_ROUND(input_plane, tile_num); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane_round * C8NUM; + int src_kh_stride = in_plane_round * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + for (int c = 0; c < oc8; c += 8) { + float *dst_ptr = tmp + c * output_plane; + const float *src_ptr = src + c * in_plane_round * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * (int)sizeof(float)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + float *tmp_dst = dst_ptr + dst_index; + const float *tmp_src = src_ptr + src_index; +#ifdef ENABLE_ARM64 + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s, v1.4s}, [x0] \n" + "ld1 {v2.4s, v3.4s}, [x1] \n" + + "fadd v0.4s, v0.4s, v2.4s \n" + "fadd v1.4s, v1.4s, v3.4s \n" + + "st1 {v0.4s, v1.4s}, [x1] \n" + + : + : [tmp_src] "r"(tmp_src), [tmp_dst] "r"(tmp_dst) + : "x0", "x1", "v0", "v1", "v2", "v3"); +#else + for (int i = 0; i < C8NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + 
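+  /* The nested loops above scattered every C8-blocked input pixel into all output
+     taps covered by the dilated kernel, accumulating into tmp. The call below
+     unpacks tmp from the C8-blocked layout into NHWC dst, adding the bias and
+     applying the configured activation. */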
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->act_type_);
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..28fe275e931dae2e14cdaac5e695e606d08ae0bd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_DECONV_H_
+#define MINDSPORE_NNACL_FP32_DECONV_H_
+
+#include <string.h>
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp32/common_func_fp32.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
+void DeConvPostFp32C8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
+                      const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_DECONV_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..673762070cb7e539e55c297dc180e4aa58b4e662
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c
@@ -0,0 +1,733 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/deconv_winograd_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; +#else + int tile_num = C4NUM; +#endif + unsigned int tmp_kernel_plane = unit->w_size_ * unit->h_size_; + unsigned int size = conv_param->input_channel_ * conv_param->output_channel_ * tmp_kernel_plane; + float *current_unit_weight = (float *)malloc(size * sizeof(float)); + if (current_unit_weight == NULL) { + return NNACL_NULL_PTR; + } + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + const float *src_ic = nhwc_weight + deconv_param->kernel_plane_ * conv_param->output_channel_ * ic; + float *dst_ic = current_unit_weight + tmp_kernel_plane * conv_param->output_channel_ * ic; + for (int uhi = 0; uhi < unit->h_size_; uhi++) { + for (int uwi = 0; uwi < unit->w_size_; uwi++) { + int src_h_offset = unit->h_start_ + uhi * conv_param->stride_h_; + int src_w_offset = unit->w_start_ + uwi * conv_param->stride_w_; + const float *src_hw = + src_ic + (src_h_offset * conv_param->kernel_w_ + src_w_offset) * conv_param->output_channel_; + float *dst_hw = dst_ic + (uhi * unit->w_size_ + uwi) * conv_param->output_channel_; + memcpy(dst_hw, src_hw, conv_param->output_channel_ * sizeof(float)); + } + } + } + + if (unit->use_winograd_) { + /* Generate winograd */ + float matrix_g[64], matrix_a[64], matrix_b[64]; + float matrix_gt[64], matrix_at[64], matrix_bt[64]; + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, + DECONV_WINOGRAD_DEFAULT_UNIT, unit->h_size_); + if (ret != NNACL_OK) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR; + } + + /* winograd AT */ + unit->winograd_.AT_ = malloc(unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float)); + if (unit->winograd_.AT_ == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + return NNACL_NULL_PTR; + } + memcpy(unit->winograd_.AT_, matrix_at, unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float)); + + /* winograd BT */ + unit->winograd_.BT_ = malloc(unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float)); + if (unit->winograd_.BT_ == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + return NNACL_NULL_PTR; + } + memcpy(unit->winograd_.BT_, matrix_bt, unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float)); + + /* winograd Weight */ + size = conv_param->input_channel_ * conv_param->output_channel_ * unit->winograd_.kh_ * unit->winograd_.kw_; + float *winograd_unit_weight = (float *)malloc(size * sizeof(float)); + if (winograd_unit_weight == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + if (unit->winograd_.BT_ != NULL) { + free(unit->winograd_.BT_); + unit->winograd_.BT_ = NULL; + } + return NNACL_NULL_PTR; + } + WinogradWeightTransform(current_unit_weight, winograd_unit_weight, matrix_g, matrix_gt, tile_num, + unit->winograd_.kh_, unit->h_size_, conv_param->output_channel_, conv_param->input_channel_, + false); + + /* reset weight data & info 
*/ + tmp_kernel_plane = unit->winograd_.kh_ * unit->winograd_.kw_; + free(current_unit_weight); + current_unit_weight = NULL; + current_unit_weight = winograd_unit_weight; + winograd_unit_weight = NULL; + } + + /* trans mhwc -> hw1:k1-knc0-c4:k1-knc5-c8:hw2:k1-knc0-c4:k1 */ + float *dst_weight = (float *)unit->weight_; + size = deconv_param->ic_up_ * deconv_param->oc_up_ * tmp_kernel_plane; + memset(dst_weight, 0, size * sizeof(float)); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < conv_param->output_channel_; oc++) { + int oc4div = oc / tile_num, oc4mod = oc % tile_num; + for (int upi = 0; upi < tmp_kernel_plane; upi++) { + int src_index = ic * conv_param->output_channel_ * tmp_kernel_plane + upi * conv_param->output_channel_ + oc; + int dst_index = upi * deconv_param->oc_up_ * deconv_param->ic_up_ + oc4div * tile_num * deconv_param->ic_up_ + + ic * tile_num + oc4mod; + dst_weight[dst_index] = current_unit_weight[src_index]; + } + } + } + + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + return NNACL_OK; +} + +void DeConvWgInputPack(const float *src_ptr, float *dst_ptr, int channel, int stride) { +#ifdef ENABLE_AVX + int ic_tile = C8NUM; +#else + int ic_tile = C4NUM; +#endif + int ic4div = channel / ic_tile; + int ic4mod = channel % ic_tile; + const float *src = src_ptr; + float *dst = dst_ptr; + + for (int ic = 0; ic < ic4div; ic++) { +#ifdef ENABLE_AVX + MS_ST256_F32(dst, MS_LD256_F32(src)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst, MS_LDQ_F32(src)); +#else + memcpy(dst, src, C4NUM * sizeof(float)); +#endif + dst += stride; + src += ic_tile; + } + + if (ic4mod != 0) { + int ic_res = 0; + for (; ic_res < ic4mod; ic_res++) { + dst[ic_res] = src[ic_res]; + } + for (; ic_res < ic_tile; ic_res++) { + dst[ic_res] = 0; + } + } +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + int dx, sz, dz; + const int src_depth_step = C4NUM * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float *dst_z = dst + dz * cal_num; + const float *weight_dz = weight + dz * ic4 * C16NUM; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { + float *dst_x = dst_z + dx * C4NUM; + dst_x[0] = 0.0f; + dst_x[1] = 0.0f; + dst_x[2] = 0.0f; + dst_x[3] = 0.0f; + const float *src_dx = src + C4NUM * dx; + for (sz = 0; sz < ic4; ++sz) { + const float *src_z = src_dx + sz * src_depth_step; + const float *weight_z = weight_dz + sz * C16NUM; + for (int i = 0; i < C4NUM; ++i) { + for (int j = 0; j < C4NUM; ++j) { + dst_x[j] += src_z[i] * weight_z[C4NUM * i + j]; + } + } + } + } + } +} +#endif + +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r11, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r11], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r11], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r11], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r11], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r11], %[src_step]\n" + "vst1.32 {q8}, [r10], 
%[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r11], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r11], %[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r11], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r8", "r10", "r11", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + return; +} +#else +void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r7, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r7], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r7], %[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r8", "r10", "r7", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + return; +} +#endif +#endif + +#ifdef ENABLE_AVX +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float *src_ptr = src; + float *dst_ptr = dst; + size_t count8 = count / C8NUM * C8NUM; + size_t count4 = count / C4NUM * C4NUM; + int i = 0; + for (; i < count8; i += C8NUM) { + MS_FLOAT32X8 src1 = MS_LD256_F32(src_ptr + 0 * src_stride); + MS_FLOAT32X8 src2 = MS_LD256_F32(src_ptr + 1 * src_stride); + MS_FLOAT32X8 src3 = MS_LD256_F32(src_ptr + 2 * src_stride); + MS_FLOAT32X8 src4 = MS_LD256_F32(src_ptr + 3 * src_stride); + MS_FLOAT32X8 src5 = MS_LD256_F32(src_ptr + 4 * src_stride); + MS_FLOAT32X8 src6 = MS_LD256_F32(src_ptr + 5 * src_stride); + MS_FLOAT32X8 src7 = MS_LD256_F32(src_ptr + 6 * src_stride); + MS_FLOAT32X8 src8 = MS_LD256_F32(src_ptr + 7 * src_stride); + MS_FLOAT32X8 dst1 = MS_LD256_F32(dst_ptr + 0 * dst_stride); + MS_FLOAT32X8 dst2 = MS_LD256_F32(dst_ptr + 1 * dst_stride); + MS_FLOAT32X8 dst3 = MS_LD256_F32(dst_ptr + 2 * dst_stride); + MS_FLOAT32X8 dst4 = MS_LD256_F32(dst_ptr + 3 * dst_stride); + MS_FLOAT32X8 dst5 = MS_LD256_F32(dst_ptr + 4 * dst_stride); + MS_FLOAT32X8 dst6 = MS_LD256_F32(dst_ptr + 5 * dst_stride); + 
MS_FLOAT32X8 dst7 = MS_LD256_F32(dst_ptr + 6 * dst_stride); + MS_FLOAT32X8 dst8 = MS_LD256_F32(dst_ptr + 7 * dst_stride); + dst1 = MS_ADD256_F32(dst1, src1); + dst2 = MS_ADD256_F32(dst2, src2); + dst3 = MS_ADD256_F32(dst3, src3); + dst4 = MS_ADD256_F32(dst4, src4); + dst5 = MS_ADD256_F32(dst5, src5); + dst6 = MS_ADD256_F32(dst6, src6); + dst7 = MS_ADD256_F32(dst7, src7); + dst8 = MS_ADD256_F32(dst8, src8); + MS_ST256_F32(dst_ptr + 0 * dst_stride, dst1); + MS_ST256_F32(dst_ptr + 1 * dst_stride, dst2); + MS_ST256_F32(dst_ptr + 2 * dst_stride, dst3); + MS_ST256_F32(dst_ptr + 3 * dst_stride, dst4); + MS_ST256_F32(dst_ptr + 4 * dst_stride, dst5); + MS_ST256_F32(dst_ptr + 5 * dst_stride, dst6); + MS_ST256_F32(dst_ptr + 6 * dst_stride, dst7); + MS_ST256_F32(dst_ptr + 7 * dst_stride, dst8); + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + for (; i < count4; i += C4NUM) { + MS_FLOAT32X8 src1 = MS_LD256_F32(src_ptr + 0 * src_stride); + MS_FLOAT32X8 src2 = MS_LD256_F32(src_ptr + 1 * src_stride); + MS_FLOAT32X8 src3 = MS_LD256_F32(src_ptr + 2 * src_stride); + MS_FLOAT32X8 src4 = MS_LD256_F32(src_ptr + 3 * src_stride); + MS_FLOAT32X8 dst1 = MS_LD256_F32(dst_ptr + 0 * dst_stride); + MS_FLOAT32X8 dst2 = MS_LD256_F32(dst_ptr + 1 * dst_stride); + MS_FLOAT32X8 dst3 = MS_LD256_F32(dst_ptr + 2 * dst_stride); + MS_FLOAT32X8 dst4 = MS_LD256_F32(dst_ptr + 3 * dst_stride); + dst1 = MS_ADD256_F32(dst1, src1); + dst2 = MS_ADD256_F32(dst2, src2); + dst3 = MS_ADD256_F32(dst3, src3); + dst4 = MS_ADD256_F32(dst4, src4); + MS_ST256_F32(dst_ptr + 0 * dst_stride, dst1); + MS_ST256_F32(dst_ptr + 1 * dst_stride, dst2); + MS_ST256_F32(dst_ptr + 2 * dst_stride, dst3); + MS_ST256_F32(dst_ptr + 3 * dst_stride, dst4); + src_ptr += C4NUM * src_stride; + dst_ptr += C4NUM * dst_stride; + } + for (; i < count; i++) { + MS_FLOAT32X8 src_data = MS_LD256_F32(src_ptr); + MS_FLOAT32X8 dst_data = MS_LD256_F32(dst_ptr); + dst_data = MS_ADD256_F32(src_data, dst_data); + MS_ST256_F32(dst_ptr, dst_data); + src_ptr += src_stride; + dst_ptr += dst_stride; + } +} +#else +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float *src_ptr = src; + float *dst_ptr = dst; + size_t count8 = count / C8NUM * C8NUM; + size_t count4 = count / C4NUM * C4NUM; + int i = 0; + for (; i < count8; i += C8NUM) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "ld1 {v4.4s}, [x7], %[src_step]\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v4.4s}, [x7], %[src_step]\n" + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], 
%[src_step]\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif defined(ENABLE_ARM32) + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + DeConvWgMergeArm32(src_ptr, dst_ptr, src_step, dst_step); +#elif defined(ENABLE_SSE) + MS_STQ_F32(dst_ptr + 0 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 0 * dst_stride), MS_LDQ_F32(src_ptr + 0 * src_stride))); + MS_STQ_F32(dst_ptr + 1 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 1 * dst_stride), MS_LDQ_F32(src_ptr + 1 * src_stride))); + MS_STQ_F32(dst_ptr + 2 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 2 * dst_stride), MS_LDQ_F32(src_ptr + 2 * src_stride))); + MS_STQ_F32(dst_ptr + 3 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 3 * dst_stride), MS_LDQ_F32(src_ptr + 3 * src_stride))); + MS_STQ_F32(dst_ptr + C4NUM * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + C4NUM * dst_stride), MS_LDQ_F32(src_ptr + C4NUM * src_stride))); + MS_STQ_F32(dst_ptr + 5 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 5 * dst_stride), MS_LDQ_F32(src_ptr + 5 * src_stride))); + MS_STQ_F32(dst_ptr + 6 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 6 * dst_stride), MS_LDQ_F32(src_ptr + 6 * src_stride))); + MS_STQ_F32(dst_ptr + 7 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 7 * dst_stride), MS_LDQ_F32(src_ptr + 7 * src_stride))); +#else + for (int j = 0; j < C8NUM; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + for (; i < count4; i += C4NUM) { +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_STQ_F32(dst_ptr + 0 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 0 * dst_stride), MS_LDQ_F32(src_ptr + 0 * src_stride))); + MS_STQ_F32(dst_ptr + 1 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 1 * dst_stride), MS_LDQ_F32(src_ptr + 1 * src_stride))); + MS_STQ_F32(dst_ptr + 2 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 2 * dst_stride), MS_LDQ_F32(src_ptr + 2 * src_stride))); + MS_STQ_F32(dst_ptr + 3 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 3 * dst_stride), MS_LDQ_F32(src_ptr + 3 * src_stride))); +#else + for (int j = 0; j < C4NUM; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C4NUM * src_stride; + dst_ptr += C4NUM * dst_stride; + } + for (; i < count; i++) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src_data = MS_LDQ_F32(src_ptr); + MS_FLOAT32X4 dst_data = MS_LDQ_F32(dst_ptr); + dst_data = MS_ADDQ_F32(src_data, dst_data); + MS_STQ_F32(dst_ptr, dst_data); +#else + for (int j = 0; j < C4NUM; j++) { + dst_ptr[j] += src_ptr[j]; + } +#endif + src_ptr += src_stride; + dst_ptr += dst_stride; + } +} +#endif + +void DeConvWgCalWgFp32(const float *tile_in, float *tile_out, const float *weight_buf, float *tmp_buf, + const float *at_buf, float *a_mid_buf, float *trans_a_buf, bool *transferred, + const float *bt_buf, float *b_tmp_buf, int unit_size, int w_start, int h_start, + const ConvParameter *conv_param, const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = 
C8NUM; + TiledMatmulFp32 matmul_func = TiledC8MatmulFp32; +#else + TiledMatmulFp32 matmul_func = TiledC4MatmulFp32; + int tile_num = C4NUM; +#endif + int winograd_plane = unit_size * unit_size; + if (!transferred[unit_size]) { + WinogradTransLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + transferred[unit_size] = true; + } + + for (int index = 0; index < winograd_plane; index++) { + float *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *dst = tmp_buf + index * deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + const float *weight = weight_buf + index * deconv_param->ic_up_ * deconv_param->oc_up_; + matmul_func(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * tile_num, deconv_param->ic_div_, + deconv_param->oc_div_); + } + WinogradTransLeft(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + + // Add to dest + for (int uhi = 0; uhi < unit_size; uhi++) { + int h_index = uhi * conv_param->stride_h_ + h_start; + for (int uwi = 0; uwi < unit_size; uwi++) { + int w_index = uwi * conv_param->stride_w_ + w_start; + + float *dst = tile_out + w_index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_ + + h_index * deconv_param->out_tile_w_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + float *src = tmp_buf + (uwi + uhi * unit_size) * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgMerge(src, dst, tile_num, tile_num, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } +} + +void DeConvWgCalCommFp32(const float *tile_in, float *tile_out, const float *weight, float *tmp_buf, int h_start, + int w_start, int h_size, int w_size, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; + TiledMatmulFp32 matmul_func = TiledC8MatmulFp32; +#else + TiledMatmulFp32 matmul_func = TiledC4MatmulFp32; + int tile_num = C4NUM; +#endif + int count = deconv_param->oc_div_ * w_size * h_size; + int in_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + int out_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + const float *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; + matmul_func(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * tile_num, deconv_param->ic_div_, count); + + for (int uhi = 0; uhi < h_size; uhi++) { + for (int uwi = 0; uwi < w_size; uwi++) { + int w_index = (wi + uwi) * conv_param->stride_w_ + w_start; + int h_index = (hi + uhi) * conv_param->stride_h_ + h_start; + float *dst = tile_out + h_index * out_stride * deconv_param->out_tile_w_ + w_index * out_stride; + float *src = tmp_buf + (uwi + uhi * w_size) * out_stride; + DeConvWgMerge(src, dst, tile_num, tile_num, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + } + } +} + +int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count, + const ConvParameter *conv_param, DeConvParam *deconv_param, int 
task_id) { + if (deconv_param->in_tile_w_count_ == 0) { + return NNACL_ERR; + } + /* pack tile input */ + int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE; +#ifdef ENABLE_AVX + int tile_num = C8NUM; + MS_FLOAT32X8 zero = MS_MOV256_F32(0.0f); +#else + int tile_num = C4NUM; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); +#endif +#endif + for (int unit_index = 0; unit_index < calculate_count; unit_index++) { + int plane_index = start_index + unit_index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + + float *dst_unit = tile_in + unit_index * tile_num; + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + float *dst = dst_unit + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * tile_in_unit_stride; + int w_index = w_start + wi; + int h_index = h_start + hi; + if (w_index >= conv_param->input_w_ || h_index >= conv_param->input_h_) { + for (int ic4_index = 0; ic4_index < deconv_param->ic_div_; ic4_index++) { +#ifdef ENABLE_AVX + MS_ST256_F32(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * tile_num, zero); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * tile_num, zero); +#else + for (int i = 0; i < tile_num; i++) { + dst[tile_num * DECONV_WINOGRAD_DEFAULT_TILE * ic4_index + i] = 0; + } +#endif + } + continue; + } + + const float *src = nhwc_input_ + (w_index + h_index * conv_param->input_w_) * conv_param->input_channel_; + DeConvWgInputPack(src, dst, conv_param->input_channel_, DECONV_WINOGRAD_DEFAULT_TILE * tile_num); + } + } + } + + /* compute */ + bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + for (int i = 0; i < deconv_param->compute_size_; i++) { + DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; + if (unit->use_winograd_) { + float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + + /* winograd a buffer */ + if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT || unit->winograd_.AT_ == NULL) { + return NNACL_ERR; + } + DeConvWgABuffer *wg_buf = &deconv_param->a_buffer_[unit->winograd_.kh_]; + float *wg_mid_a_buf = (float *)wg_buf->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *wg_dst_a_buf = (float *)wg_buf->dest_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + DeConvWgCalWgFp32(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf, + wg_dst_a_buf, transferred, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, unit->w_start_, + unit->h_start_, conv_param, deconv_param); + } else { + float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * deconv_param->oc_div_ * unit->w_size_ * unit->h_size_ * + DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + DeConvWgCalCommFp32(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->h_start_, unit->w_start_, + unit->h_size_, unit->w_size_, conv_param, 
deconv_param); + } + } + return NNACL_OK; +} + +int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; +#else + int tile_num = C4NUM; +#endif + + /* merge */ + int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + int src_stride = DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + int dst_stride = conv_param->output_w_ * conv_param->output_h_ * tile_num; + + for (int index = 0; index < calculate_count; ++index) { + const float *src_start = tile_out + index * tile_num; + + int plane_index = tile_index * DECONV_WINOGRAD_DEFAULT_TILE + index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_w_ - conv_param->pad_l_; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_h_ - conv_param->pad_u_; + float *dst_start = nc4hw4_output + h_start * conv_param->output_w_ * tile_num + w_start * tile_num; + + int merge_w_start = MSMAX(-w_start, 0); + int merge_h_start = MSMAX(-h_start, 0); + int merge_h_end = MSMIN(deconv_param->out_tile_h_, conv_param->output_h_ - h_start); + int merge_w_end = MSMIN(deconv_param->out_tile_w_, conv_param->output_w_ - w_start); + + for (int hi = merge_h_start; hi < merge_h_end; hi++) { + for (int wi = merge_w_start; wi < merge_w_end; wi++) { + const float *src = src_start + (hi * deconv_param->out_tile_w_ + wi) * src_unit_stride; + float *dst = dst_start + (hi * conv_param->output_w_ + wi) * tile_num; + DeConvWgMerge(src, dst, src_stride, dst_stride, deconv_param->oc_div_); + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..576ce4e769700e692f12b88081601a45cd739275 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_
+#define MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_
+
+#include <string.h>
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp32/common_func_fp32.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+// Parameter names follow the definitions and call sites: the tile count comes first.
+typedef void (*TiledMatmulFp32)(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic_tiled,
+                                size_t oc_tiled);
+
+int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param,
+                         const DeConvParam *deconv_param);
+int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
+             const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
+int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param,
+                 const DeConvParam *deconv_param, int calculate_count, int tile_index);
+void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4);
+void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic8, size_t oc8);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..45c29cbfc6328b465aa79384a8f882409cff6bae
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c
@@ -0,0 +1,235 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/detection_post_process_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/nnacl_utils.h"
+
+float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
+  const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
+  const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
+  if (area_a <= 0 || area_b <= 0) {
+    return 0.0f;
+  }
+  const float ymin = a->ymin > b->ymin ? a->ymin : b->ymin;
+  const float xmin = a->xmin > b->xmin ? a->xmin : b->xmin;
+  const float ymax = a->ymax < b->ymax ? a->ymax : b->ymax;
+  const float xmax = a->xmax < b->xmax ? a->xmax : b->xmax;
+  const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f;
+  const float w = xmax - xmin > 0.0f ?
xmax - xmin : 0.0f; + const float inter = h * w; + return inter / (area_a + area_b - inter); +} + +int DecodeBoxes(int num_boxes, const float *input_boxes, const float *anchors, + const DetectionPostProcessParameter *param) { + if (input_boxes == NULL || anchors == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + float *decoded_boxes = (float *)param->decoded_boxes_; + BboxCenter scaler; + scaler.y = param->y_scale_; + scaler.x = param->x_scale_; + scaler.h = param->h_scale_; + scaler.w = param->w_scale_; + for (int i = 0; i < num_boxes; ++i) { + BboxCenter *box = (BboxCenter *)(input_boxes) + i; + BboxCenter *anchor = (BboxCenter *)(anchors) + i; + BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes) + i; + float y_center = box->y / scaler.y * anchor->h + anchor->y; + float x_center = box->x / scaler.x * anchor->w + anchor->x; + const float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h; + const float w_half = 0.5f * expf(box->w / scaler.w) * anchor->w; + decoded_box->ymin = y_center - h_half; + decoded_box->xmin = x_center - w_half; + decoded_box->ymax = y_center + h_half; + decoded_box->xmax = x_center + w_half; + } + return NNACL_OK; +} + +int NmsSingleClass(const int num_boxes, const float *decoded_boxes, const int max_detections, const float *scores, + int32_t *selected, void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + uint8_t *nms_candidate = param->nms_candidate_; + const int output_num = num_boxes < max_detections ? num_boxes : max_detections; + int possible_candidate_num = num_boxes; + int selected_num = 0; + int32_t *indexes = (int32_t *)param->single_class_indexes_; + for (int i = 0; i < num_boxes; ++i) { + indexes[i] = i; + nms_candidate[i] = 1; + } + PartialArgSort(scores, indexes, num_boxes, num_boxes); + for (int i = 0; i < num_boxes; ++i) { + if (possible_candidate_num == 0 || selected_num >= output_num || scores[indexes[i]] < param->nms_score_threshold_) { + break; + } + if (nms_candidate[indexes[i]] == 0) { + continue; + } + selected[selected_num++] = indexes[i]; + nms_candidate[indexes[i]] = 0; + possible_candidate_num--; + const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + indexes[i]; + for (int t = i + 1; t < num_boxes; ++t) { + if (scores[indexes[t]] < param->nms_score_threshold_) break; + if (nms_candidate[indexes[t]] == 1) { + const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + indexes[t]; + const float iou = IntersectionOverUnion(bbox_i, bbox_t); + if (iou > param->nms_iou_threshold_) { + nms_candidate[indexes[t]] = 0; + possible_candidate_num--; + } + } + } + } + return selected_num; +} + +int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param, const int task_id, const int thread_num) { + if (input_scores == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + const int64_t max_classes_per_anchor = + param->max_classes_per_detection_ < param->num_classes_ ? 
param->max_classes_per_detection_ : param->num_classes_; + float *scores = (float *)param->scores_; + for (int i = task_id; i < num_boxes; i += thread_num) { + int32_t *indexes = (int32_t *)param->indexes_ + i * param->num_classes_; + for (int j = 0; j < param->num_classes_; ++j) { + indexes[j] = i * num_classes_with_bg + first_class_index + j; + } + PartialArgSort(input_scores, indexes, max_classes_per_anchor, param->num_classes_); + scores[i] = input_scores[indexes[0]]; + } + return NNACL_OK; +} + +int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + const float *decoded_boxes, float *output_boxes, float *output_classes, + float *output_scores, float *output_num, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (input_scores == NULL || decoded_boxes == NULL || output_boxes == NULL || output_classes == NULL || + output_scores == NULL || output_num == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + int out_num = 0; + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + const int64_t max_classes_per_anchor = + param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_; + int32_t *selected = (int32_t *)param->selected_; + int selected_num = NmsSingleClass(num_boxes, decoded_boxes, param->max_detections_, (float *)param->scores_, selected, + PartialArgSort, param); + for (int i = 0; i < selected_num; ++i) { + int32_t *indexes = (int32_t *)param->indexes_ + selected[i] * param->num_classes_; + BboxCorner *box = (BboxCorner *)(decoded_boxes) + selected[i]; + for (int j = 0; j < max_classes_per_anchor; ++j) { + *((BboxCorner *)(output_boxes) + out_num) = *box; + output_scores[out_num] = input_scores[indexes[j]]; + NNACL_ASSERT(num_classes_with_bg != 0); + output_classes[out_num++] = (float)(indexes[j] % num_classes_with_bg - first_class_index); + } + } + *output_num = (float)out_num; + for (int i = out_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) { + ((BboxCorner *)(output_boxes) + i)->ymin = 0; + ((BboxCorner *)(output_boxes) + i)->xmin = 0; + ((BboxCorner *)(output_boxes) + i)->ymax = 0; + ((BboxCorner *)(output_boxes) + i)->xmax = 0; + output_scores[i] = 0; + output_classes[i] = 0; + } + return NNACL_OK; +} + +int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + float *output_boxes, float *output_classes, float *output_scores, float *output_num, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (input_scores == NULL || output_boxes == NULL || output_classes == NULL || output_scores == NULL || + output_num == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + float *decoded_boxes = (float *)param->decoded_boxes_; + int32_t *selected = (int32_t *)param->selected_; + float *scores = (float *)param->scores_; + float *all_scores = (float *)param->all_class_scores_; + int32_t *indexes = (int32_t *)(param->indexes_); + int32_t *all_indexes = (int32_t *)(param->all_class_indexes_); + int all_classes_sorted_num = 0; + int all_classes_output_num = 0; + for (int j = first_class_index; j < num_classes_with_bg; ++j) { + // process single class + for (int i = 0; i < num_boxes; ++i) { + scores[i] = 
input_scores[i * num_classes_with_bg + j]; + } + int selected_num = + NmsSingleClass(num_boxes, decoded_boxes, param->detections_per_class_, scores, selected, PartialArgSort, param); + for (int i = 0; i < all_classes_sorted_num; ++i) { + indexes[i] = all_indexes[i]; + all_indexes[i] = i; + } + // process all classes + for (int i = 0; i < selected_num; ++i) { + indexes[all_classes_sorted_num] = selected[i] * num_classes_with_bg + j; + all_indexes[all_classes_sorted_num] = all_classes_sorted_num; + all_scores[all_classes_sorted_num++] = scores[selected[i]]; + } + all_classes_output_num = + all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_; + PartialArgSort(all_scores, all_indexes, all_classes_output_num, all_classes_sorted_num); + for (int i = 0; i < all_classes_output_num; ++i) { + scores[i] = all_scores[all_indexes[i]]; + all_indexes[i] = indexes[all_indexes[i]]; + } + for (int i = 0; i < all_classes_output_num; ++i) { + all_scores[i] = scores[i]; + } + all_classes_sorted_num = all_classes_output_num; + } + for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) { + if (i < all_classes_output_num) { + NNACL_CHECK_ZERO_RETURN_ERR(num_classes_with_bg); + const int box_index = all_indexes[i] / num_classes_with_bg; + const int class_index = all_indexes[i] % num_classes_with_bg - first_class_index; + *((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index); + output_classes[i] = (float)class_index; + output_scores[i] = all_scores[i]; + } else { + ((BboxCorner *)(output_boxes) + i)->ymin = 0; + ((BboxCorner *)(output_boxes) + i)->xmin = 0; + ((BboxCorner *)(output_boxes) + i)->ymax = 0; + ((BboxCorner *)(output_boxes) + i)->xmax = 0; + output_classes[i] = 0.0f; + output_scores[i] = 0.0f; + } + } + *output_num = (float)all_classes_output_num; + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..4df46a790681b2033da8f67359d5738ebade5d14 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_
+#define MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/detection_post_process_parameter.h"
+
+typedef struct {
+  float y;
+  float x;
+  float h;
+  float w;
+} BboxCenter;
+
+typedef struct {
+  float ymin;
+  float xmin;
+  float ymax;
+  float xmax;
+} BboxCorner;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int DecodeBoxes(int num_boxes, const float *input_boxes, const float *anchors,
+                const DetectionPostProcessParameter *param);
+
+int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
+                            void (*)(const float *, int32_t *, int, int), const DetectionPostProcessParameter *param,
+                            const int task_id, const int thread_num);
+
+int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
+                             const float *decoded_boxes, float *output_boxes, float *output_classes,
+                             float *output_scores, float *output_num, void (*)(const float *, int32_t *, int, int),
+                             const DetectionPostProcessParameter *param);
+
+int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
+                                float *output_boxes, float *output_classes, float *output_scores, float *output_num,
+                                void (*)(const float *, int32_t *, int, int),
+                                const DetectionPostProcessParameter *param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..a0aa229bd4cb85103cdd7fc434940a5b4e83b628
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c
@@ -0,0 +1,136 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/div_fp32.h"
+#include <math.h>
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/div_fp32_simd.h"
+
+int ElementOptDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar) {
+  int index = 0;
+
+  if (first_scalar) {
+    SIMD_RUN_NO_SCALAR(ElementOptDivNum0, index, in0, in1, out, size);
+
+    for (; index < size; index++) {
+      out[index] = in0[0] / in1[index];
+    }
+  } else {
+    SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[index] / in1[0];
+    }
+  }
+  return NNACL_OK;
+}
+
+int ElementOptDivRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) {
+  int index = 0;
+  if (first_scalar) {
+    SIMD_RUN_NO_SCALAR(ElementOptDivReluNum0, index, in0, in1, out, size);
+
+    for (; index < size; index++) {
+      out[index] = in0[0] / in1[index];
+      out[index] = out[index] > 0 ?
out[index] : 0; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivReluNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + out[index] = out[index] > 0 ? out[index] : 0; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivRelu6Num0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] / in1[index], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivRelu6Num1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] / in1[0], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + } + return NNACL_OK; +} + +int ElementOptDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivIntNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[index] != 0); + out[index] = in0[0] / in1[index]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0] != 0); + + SIMD_RUN_NO_SCALAR(ElementOptDivIntNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + } + } + return NNACL_OK; +} + +int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementDiv(tile_in0, tile_in1, out, size); +} + +int ElementDiv(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDiv, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] / in1[index]; + } + return NNACL_OK; +} + +int ElementDivRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDivRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] / in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementDivRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDivRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] / in1[index], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..8a966e39b23afe4599ba95a955638ea9956e941e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
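In the `ElementOptDiv*` family above, `first_scalar` records which operand arrived as a single broadcast element. A scalar model of just that convention (it mirrors only the tail loops; the vector fast path is dispatched separately via `SIMD_RUN_NO_SCALAR`):

```c
#include <stdio.h>
#include <stdbool.h>

/* Mirrors the scalar tail loops of ElementOptDiv above. */
static void opt_div_sketch(const float *in0, const float *in1, float *out, int size, bool first_scalar) {
  for (int i = 0; i < size; ++i) {
    out[i] = first_scalar ? in0[0] / in1[i]   /* in0 is the broadcast scalar */
                          : in0[i] / in1[0];  /* in1 is the broadcast scalar */
  }
}

int main(void) {
  float a[1] = {8.0f}, b[3] = {2.0f, 4.0f, 8.0f}, out[3];
  opt_div_sketch(a, b, out, 3, true);
  printf("%g %g %g\n", out[0], out[1], out[2]); /* 4 2 1 */
  return 0;
}
```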
+ */
+#ifndef MINDSPORE_NNACL_FP32_DIV_H_
+#define MINDSPORE_NNACL_FP32_DIV_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/arithmetic_base.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ElementDiv(const float *in0, const float *in1, float *out, int size);
+int ElementDivRelu(const float *in0, const float *in1, float *out, int size);
+int ElementDivRelu6(const float *in0, const float *in1, float *out, int size);
+int ElementOptDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptDivRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptDivRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar);
+int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
+                 ArithmeticParameter *param);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MINDSPORE_NNACL_FP32_DIV_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..1495b6f5949dbd4d6e4469313a548d45eb6be56a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
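Every body generated from the `.h.in` template above follows one idiom: iterate while `index < size - BLOCK_NUM + 1`, consume `BLOCK_NUM` lanes per step, and return `index` so the plain-C caller can finish the remainder with scalars. A self-contained sketch of that hand-off, with an inner loop standing in for a single SIMD instruction:

```c
#include <stdio.h>
#define BLOCK_NUM 4 /* one vector width; the real value comes from the SIMD backend */

/* Process whole blocks only; return the first index the block pass could not cover. */
static int block_pass(int index, const float *in, float *out, int size) {
  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
    for (int k = 0; k < BLOCK_NUM; ++k) out[index + k] = in[index + k] * 2.0f;
  }
  return index;
}

int main(void) {
  float in[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, out[10];
  int i = block_pass(0, in, out, 10);
  for (; i < 10; ++i) out[i] = in[i] * 2.0f; /* scalar tail, as in the .c files above */
  printf("%.0f %.0f\n", out[8], out[9]);     /* 16 18 */
  return 0;
}
```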
+ */ + +#ifndef MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptDivNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0_opt, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0_opt, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0_opt, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0_opt, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 
vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDiv@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDivInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementDivRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDivRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..65630afca25b2498fdfa02db5e40b85a3580e17b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c @@ -0,0 +1,28 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/dropout_fp32.h" +#include "nnacl_c/dropout_fp32_simd.h" + +void DropoutFp32(const float *input, float scale, int length, float *output) { + int i = 0; + + SIMD_RUN_NO_SCALAR(DropoutFp32, i, input, scale, length, output); + + for (; i < length; ++i) { + output[i] = scale * input[i]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..50193b601128f83937635d0d8a7cc5893948ab08 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h @@ -0,0 +1,28 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ +#define MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DropoutFp32(const float *input, float scale, int length, float *output); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..36109cd8108b5e0723d524b85e90c86b3e6ed2b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
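`DropoutFp32` above is the inference-time form of dropout: a uniform rescale of the activations. Assuming the usual inverted-dropout convention `scale = 1 / (1 - ratio)` (the parameter plumbing is outside this diff, so that convention is an assumption here), a scalar model and usage:

```c
#include <stdio.h>

/* Scalar equivalent of the DropoutFp32 loop above. */
static void dropout_sketch(const float *input, float scale, int length, float *output) {
  for (int i = 0; i < length; ++i) output[i] = scale * input[i];
}

int main(void) {
  float in[4] = {1.0f, 2.0f, 3.0f, 4.0f}, out[4];
  float ratio = 0.2f; /* hypothetical keep-probability setup */
  dropout_sketch(in, 1.0f / (1.0f - ratio), 4, out);
  printf("%.2f\n", out[3]); /* 5.00 */
  return 0;
}
```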
+ */ +#ifndef MINDSPORE_NNACL_FP32_DROPOUTFP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_DROPOUTFP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int DropoutFp32@SIMD_INSTRUCTION@(int index, const float *input, float scale, + int length, float *output) { + SIMD_F32 scale_value = SIMD_MOV_F32(scale); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MUL_F32(SIMD_LD_F32(input + index), scale_value)); + } + return index; +} +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f2c6425c8694d8b115aa8438ddafd3516a193d96 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/embedding_lookup_fp32.h" +#include +#include "nnacl_c/errorcode.h" + +void l2_regulate(float *data, int size, float max_norm) { + float sum = 0; + for (int i = 0; i < size; ++i) { + sum += data[i]; + } + if (sum != 0) { + for (int i = 0; i < size; ++i) { + data[i] *= max_norm / sum; + } + } + return; +} + +int CopyData(float *input_data, const int32_t *ids, float *output_data, int num, + const EmbeddingLookupParameter *parameter) { + if (ids[num] >= parameter->layer_num_ || ids[num] < 0) { + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + float *out_data = output_data + num * parameter->layer_size_; + float *in_data = input_data + ids[num] * parameter->layer_size_; + if (!parameter->is_regulated_[ids[num]]) { + l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_); + parameter->is_regulated_[ids[num]] = true; + } + + memcpy(out_data, in_data, sizeof(float) * (size_t)(parameter->layer_size_)); + return NNACL_OK; +} + +int EmbeddingLookup(float *input_data, const int32_t *ids, float *output_data, + const EmbeddingLookupParameter *parameter, int task_id) { + if (parameter->op_parameter_.thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + for (int i = task_id; i < parameter->ids_size_; i += parameter->op_parameter_.thread_num_) { + int ret = CopyData(input_data, ids, output_data, i, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..236d053d9ef5393a58629c39d2cfc849aa87eecf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ +#define MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ + +#include "nnacl_c/op_base.h" + +typedef struct EmbeddingLookupParameter { + OpParameter op_parameter_; + // primitive parameter + float max_norm_; + + // shape correlative + bool *is_regulated_; + int ids_size_; + int layer_size_; + int layer_num_; +} EmbeddingLookupParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int EmbeddingLookup(float *input_data, const int32_t *ids, float *output_data, + const EmbeddingLookupParameter *parameter, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..cab8c075aaf29865e954ee85d14c1b1fc576ab1e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c @@ -0,0 +1,62 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/exp_fp32_simd.h" +#include +#include +#include "nnacl_c/errorcode.h" + +void ExpFp32(const float *src, float *dst, int num) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ExpFp32, i, src, dst, num); + for (; i < num; ++i) { + simd_exp32(src[i], dst + i); + } +} + +int ExpFusionFp32(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + const float *src = (const float *)src_data; + float *dst = (float *)dst_data; + + int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_); + int start = stride * task_id; + int end = MSMIN(exp->element_num_, start + stride); + int num = end - start; + + if (param->scale_ == 1) { + ExpFp32(src + start, dst + start, num); + } else { + int i = 0; + SIMD_RUN_NO_SCALAR(ExpFp32WithInScale, i, src, dst, num, exp->in_scale_); + for (; i < num; ++i) { + simd_exp32(src[i] * exp->in_scale_, dst + i); + } + } + if (exp->out_scale_ != 1) { + int i = 0; + SIMD_RUN_NO_SCALAR(ExpFp32WithOutScale, i, src, dst, num, exp->out_scale_); + for (; i < num; ++i) { + simd_exp32(src[i], dst + i); + dst[i] *= exp->out_scale_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d02a05628fc8d6b4e73916ed6d938b178a99f54c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
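`ExpFusionFp32` above splits the flat element range across threads with ceiling division: each task owns `[stride * task_id, min(n, stride * task_id + stride))`. The range arithmetic in isolation, with `UP_DIV`/`MSMIN` redeclared locally to match their nnacl op_base.h definitions:

```c
#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

int main(void) {
  int element_num = 10, thread_nr = 3;
  int stride = UP_DIV(element_num, thread_nr); /* 4 */
  for (int task_id = 0; task_id < thread_nr; ++task_id) {
    int start = stride * task_id;
    int end = MSMIN(element_num, start + stride);
    printf("task %d: [%d, %d)\n", task_id, start, end); /* [0,4) [4,8) [8,10) */
  }
  return 0;
}
```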
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_EXP_H_ +#define MINDSPORE_NNACL_FP32_EXP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/exp.h" +#include "nnacl_c/exp_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ExpFp32(const float *src, float *dst, int num); +int ExpFusionFp32(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_EXP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..e884e43ebead5fea33df7d15de0beef2896dbeae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_EXP_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_EXP_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int64_t ExpFp32@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num) {
+  for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_LD_F32(src + index), dst + index);
+  }
+  return index;
+}
+
+static inline int64_t ExpFp32WithInScale@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num, float in_scale) {
+  SIMD_F32 scale_vec = SIMD_MOV_F32(in_scale);
+  for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_MUL_F32(SIMD_LD_F32(src + index), scale_vec), dst + index);
+  }
+  return index;
+}
+
+static inline int64_t ExpFp32WithOutScale@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num, float out_scale) {
+  SIMD_F32 scale_vec = SIMD_MOV_F32(out_scale);
+  for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_LD_F32(src + index), dst + index);
+    SIMD_ST_F32(dst + index, SIMD_MUL_F32(SIMD_LD_F32(dst + index), scale_vec));
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+};
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..ed9e6cf2580ad0b58bdd05a3b28b4e92967bc604
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/gatherNd_fp32.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+int GatherNd(const void *input, void *output, const int32_t *in_offset, int area, int count, int data_type_len) {
+  int i = 0;
+  for (i = 0; i < count; i++) {
+    (void)memcpy((int8_t *)output + area * i * data_type_len, (int8_t *)input + in_offset[i] * data_type_len,
+                 (size_t)(area)*data_type_len);
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..ffe80119d19979176f28c8dd34ed135b060dc4ef
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
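`GatherNd` above only performs the final block copies; it expects `in_offset` to be precomputed by the calling kernel, which is not part of this file. A plausible derivation for the simplest case, indices that each select one row of a row-major [3][4] input (so `area` is 4 elements per copy); the offset formula is an assumption about the wrapper, not code from this diff:

```c
#include <stdio.h>
#include <string.h>
#include <stdint.h>

int main(void) {
  float input[3][4] = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
  int32_t indices[2] = {2, 0}; /* two gathers, each picking one row */
  int32_t in_offset[2];
  for (int i = 0; i < 2; ++i) in_offset[i] = indices[i] * 4; /* stride of dim 0 in elements */

  /* This loop is the work GatherNd above does with memcpy over raw bytes. */
  float output[2][4];
  for (int i = 0; i < 2; ++i) {
    memcpy(output[i], (const float *)input + in_offset[i], sizeof(float) * 4);
  }
  printf("%g %g\n", output[0][0], output[1][3]); /* 8 3 */
  return 0;
}
```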
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GATHERND_FP32_H_
+#define NNACL_FP32_GATHERND_FP32_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int GatherNd(const void *input, void *output, const int32_t *in_offset, int area, int count, int data_type_len);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GATHERND_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..d6de5be206700428be719021d033f01f8b45ef34
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/group_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/group_norm_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/group_norm_fp32_simd.h"
+
+static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
+                                 int cur_groups, const GroupNormParameter *param);
+
+int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
+                  const GroupNormParameter *param, int task_id, float *output) {
+  if (param->op_parameter_.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  const int frame_elem_num = param->unit_ * param->channel_;
+  const int groups_per_thread = UP_DIV(param->num_groups_, param->op_parameter_.thread_num_);
+  const int completed_group = task_id * groups_per_thread;
+  const int cur_group = MSMIN(groups_per_thread, param->num_groups_ - completed_group);
+  const int num_of_ch_per_group = param->channel_ / param->num_groups_;
+  int cur_offset = completed_group * num_of_ch_per_group * param->unit_;
+
+  for (int b = 0; b < param->batch_; b++) {
+    const float *b_in = input + b * frame_elem_num;
+    float *b_out = output + b * frame_elem_num;
+    int b_offset = cur_offset;
+    GroupNormFp32MeanVar(b_in, mean, variance, completed_group, cur_group, param);
+    for (int g = 0; g < cur_group; g++) {
+      int grp_idx = g + completed_group;
+      int c_offset = grp_idx * num_of_ch_per_group;
+      float m = mean[grp_idx];
+      float v = variance[grp_idx];
+      float variance_sqrt = sqrtf(v + param->epsilon_);
+      if (variance_sqrt == 0) {
+        return NNACL_ERR;
+      }
+      for (int c = 0; c < num_of_ch_per_group; c++) {
+        const float *unit_input = b_in + b_offset;
+        float *unit_output = b_out + b_offset;
+        float s = scale[c_offset + c];
+        float o = offset[c_offset
+ c]; + int u = 0; + SIMD_RUN_NO_SCALAR(GroupNormFp32, u, unit_input, s, o, m, variance_sqrt, param->unit_, unit_output); + for (; u < param->unit_; u++) { + float norm_val = (unit_input[u] - m) / variance_sqrt; + unit_output[u] = norm_val * s + o; + } + b_offset += param->unit_; + } + } + } + return NNACL_OK; +} + +#define SimdReduceSum(block_size, block_num, in, i, sum) \ + do { \ + for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \ + MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, in + i); \ + sum += MS_GET_SUM_F32(block_size, input); \ + } \ + } while (0) + +#define SimdReduceVar(block_size, block_num, in, m, i, sum) \ + do { \ + MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m); \ + MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0); \ + for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \ + MS_FLOAT_32xN(block_num) input = MS_SUB_F32(block_size, MS_LD_F32(block_size, in + i), mean); \ + tmp = MS_ADD_F32(block_size, tmp, MS_MUL_F32(block_size, input, input)); \ + } \ + sum += MS_GET_SUM_F32(block_size, tmp); \ + } while (0) + +static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group, + int cur_groups, const GroupNormParameter *param) { + const int num_of_ch_per_group = param->channel_ / param->num_groups_; + const float N = (float)(param->unit_ * num_of_ch_per_group); + + // calc mean + for (int g = 0; g < cur_groups; g++) { + int g_idx = g + completed_group; + float sum = 0; + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_; + int i = 0; + SIMD_RUN_NO_SCALAR(GroupNormReduceSum, i, in, &sum, param->unit_); + for (; i < param->unit_; i++) { + sum += in[i]; + } + } + run_mean[g_idx] = sum / N; + } + + // calc variance + for (int g = 0; g < cur_groups; g++) { + int g_idx = g + completed_group; + float var = 0; + run_var[g_idx] = 0; + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_; + int i = 0; + SIMD_RUN_NO_SCALAR(GroupNormReduceVar, i, in, run_mean[g_idx], &var, param->unit_); + for (; i < param->unit_; i++) { + var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]); + } + } + run_var[g_idx] = var / N; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..f5b595ae5af3cd1dc9d5a904e13603ecefa5aa2d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
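For reference, the math `GroupNormFp32` and `GroupNormFp32MeanVar` implement per group: biased mean and variance over the group's `(channel_ / num_groups_) * unit_` elements, then a per-channel affine-scaled normalization. One group worked in plain C:

```c
#include <stdio.h>
#include <math.h>

int main(void) {
  float x[4] = {1, 2, 3, 6}; /* one group, 4 elements */
  float eps = 1e-5f, scale = 1.0f, offset = 0.0f;
  float mean = 0, var = 0;
  for (int i = 0; i < 4; ++i) mean += x[i];
  mean /= 4; /* 3.0 */
  for (int i = 0; i < 4; ++i) var += (x[i] - mean) * (x[i] - mean);
  var /= 4; /* 3.5, the biased estimator, matching the division by N above */
  float variance_sqrt = sqrtf(var + eps);
  for (int i = 0; i < 4; ++i) printf("%.3f ", (x[i] - mean) / variance_sqrt * scale + offset);
  printf("\n"); /* roughly -1.069 -0.535 0.000 1.604 */
  return 0;
}
```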
+ */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/group_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance, + const GroupNormParameter *param, int task_id, float *output); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..33bde9b7ffb8a167edf4b9010a058d93806aa59e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in @@ -0,0 +1,70 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t GroupNormFp32@SIMD_INSTRUCTION@(int64_t index, const float *unit_input, float scale, float offset, float mean, + float var_sqrt, int unit, float *unit_output) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 v_sqrt = SIMD_MOV_F32(var_sqrt); + SIMD_F32 scale_val = SIMD_MOV_F32(scale); + SIMD_F32 offset_val = SIMD_MOV_F32(offset); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(unit_input + index); + SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input, mean_val), v_sqrt); + SIMD_F32 output = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_val), offset_val); + SIMD_ST_F32(unit_output + index, output); + } + return index; +} + +static inline int64_t GroupNormReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *in, float *sum, int unit) { + if (unit - index >= 4 * BLOCK_NUM) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(in + index)); + } + *sum += SIMD_GET_SUM_F32(tmp); + } + return index; +} + +static inline int64_t GroupNormReduceVar@SIMD_INSTRUCTION@(int64_t index, const float *in, float mean, float *sum, int unit) { + if (unit - index >= 4 * BLOCK_NUM) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_SUB_F32(SIMD_LD_F32(in + index), mean_val); + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_F32(input, 
input));
+    }
+    *sum += SIMD_GET_SUM_F32(tmp);
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..cf113430fd9d11f79b08a77cf845e1dd92ee6a4d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c
@@ -0,0 +1,154 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/gru_fp32.h"
+#include <string.h>
+#include "nnacl_c/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+void GruMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+  if (is_vec) {
+    MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
+  } else {
+    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
+  }
+}
+
+void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight,
+                 const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) {
+  float *packed_state = buffer[2];
+  float *state_gate = buffer[3];
+  bool is_vec = gru_param->batch_ == 1;
+
+  const float *state_update_weight = state_weight;
+  const float *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
+  const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
+  float *state_update_gate = state_gate;
+  float *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_;
+  float *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2;
+  const float *state_update_bias = state_bias;
+  const float *state_reset_bias = state_bias + gru_param->hidden_size_;
+  const float *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2;
+
+  // state * weight
+  if (is_vec) {
+    GruMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    GruMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  } else {
+    PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_);
+    GruMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    GruMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  }
+  ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
+  ElementAdd(reset_gate, state_update_gate +
gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); + + // update reset_gate + Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate); + // update update_gate + Sigmoid(update_gate, gru_param->batch_ * gru_param->hidden_size_, update_gate); + + ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + GruMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_); + GruMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + Tanh(hidden_buffer, gru_param->batch_ * gru_param->hidden_size_, hidden_buffer); + + ElementMul(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + const float one = 1.0f; + ElementOptSub(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, true); + + ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float)); +} + +void GruUnidirectional(float *output, const float *packed_input, const float *weight_g, const float *weight_r, + const float *input_bias, const float *state_bias, float *hidden_state, float *buffer[4], + const GruParameter *gru_param, bool is_backward) { + float *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float *bias_loop = input_bias + gru_param->input_col_align_ * i; + float *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float *update_gate = gate; + float *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? 
gru_param->seq_len_ - t - 1 : t; + float *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnit(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, buffer, + gru_param); + } +} + +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias, + const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len, + const GruParameter *gru_param) { + // forward + float *packed_input = buffer[0]; + PackLstmInput(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_); + GruUnidirectional(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, gru_param, + false); + + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_param->bidirectional_) { + const float *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; + float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; + + GruUnidirectional(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + + // zero out extra bw outputs + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..cba05c1b48599852158241ceb6870dc2d6a31967 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
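The tail of `GruStepUnit` above realizes `h_t = u * h_prev + (1 - u) * h_tilde` with three element kernels (`ElementMul`, `ElementOptSub`, `ElementMulAcc`). The same update as one scalar loop, using pre-activated toy gate values:

```c
#include <stdio.h>

int main(void) {
  float u[2] = {0.25f, 0.75f};    /* update gate, already sigmoid-activated */
  float h[2] = {1.0f, 1.0f};      /* previous hidden state */
  float cand[2] = {0.0f, 2.0f};   /* candidate state h_tilde, already tanh-activated */
  for (int i = 0; i < 2; ++i) {
    /* ElementMul + ElementOptSub + ElementMulAcc collapsed into one expression */
    h[i] = u[i] * h[i] + (1.0f - u[i]) * cand[i];
  }
  printf("%.3f %.3f\n", h[0], h[1]); /* 0.250 1.250 */
  return 0;
}
```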
+ */
+#ifndef MINDSPORE_NNACL_FP32_GRU_FP32_H_
+#define MINDSPORE_NNACL_FP32_GRU_FP32_H_
+#include "nnacl_c/gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias,
+         const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len,
+         const GruParameter *gru_parm);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MINDSPORE_NNACL_FP32_GRU_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..455eb07138266939e61cd46a7e2f382048367950
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c
@@ -0,0 +1,374 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/instance_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data,
+                 const InstanceNormParameter *param, size_t task_id) {
+  NNACL_CHECK_NULL_RETURN_ERR(src_data);
+  NNACL_CHECK_NULL_RETURN_ERR(dst_data);
+  NNACL_CHECK_NULL_RETURN_ERR(gamma_data);
+  NNACL_CHECK_NULL_RETURN_ERR(beta_data);
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
+  int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_);
+  int channel_begin = (int)(task_id)*channel_step;
+  int channel_end = MSMIN(channel_begin + channel_step, param->channel_);
+
+  for (int b = 0; b < param->batch_; b++) {
+    const float *src_b = src_data + b * param->channel_ * param->inner_size_;
+    float *dst_b = dst_data + b * param->channel_ * param->inner_size_;
+    for (int c = channel_begin; c < channel_end; c++) {
+      const float *src = src_b + c * param->inner_size_;
+      float *dst = dst_b + c * param->inner_size_;
+      double mean = 0.0f;
+      double squ_m = 0.0f;
+
+      int index = 0;
+#if defined(ENABLE_AVX)
+      for (; index <= param->inner_size_ - C8NUM; index += C8NUM) {
+        __m256 srcv = _mm256_loadu_ps(src + index);
+        __m256 squarev = _mm256_mul_ps(srcv, srcv);
+        __m128 src128 = _mm_add_ps(_mm256_extractf128_ps(srcv, 0), _mm256_extractf128_ps(srcv, 1));
+        __m128 square128 = _mm_add_ps(_mm256_extractf128_ps(squarev, 0), _mm256_extractf128_ps(squarev, 1));
+        for (int i = 0; i < C4NUM; ++i) {
+          mean += MS_F32X4_GETI(src128, i);
+          squ_m += MS_F32X4_GETI(square128, i);
+        }
+      }
+#endif
+
+#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
+      for (; index <= param->inner_size_ - C4NUM; index += C4NUM) {
+        MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index);
+        MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv);
+#ifdef ENABLE_ARM64
+        mean += vaddvq_f32(srcv);
+        squ_m += vaddvq_f32(squarev);
+#elif defined(ENABLE_SSE)
+        for (int i = 0; i <
C4NUM; ++i) { + mean += MS_F32X4_GETI(srcv, i); + squ_m += MS_F32X4_GETI(squarev, i); + } +#else + float32x2_t src_add2 = vadd_f32(vget_low_f32(srcv), vget_high_f32(srcv)); + float32x2_t src_add4 = vpadd_f32(src_add2, src_add2); + mean += vget_lane_f32(src_add4, 0); + float32x2_t square_add2 = vadd_f32(vget_low_f32(squarev), vget_high_f32(squarev)); + float32x2_t square_add4 = vpadd_f32(square_add2, square_add2); + squ_m += vget_lane_f32(square_add4, 0); +#endif + } +#endif + for (; index < param->inner_size_; index++) { + mean += src[index]; + squ_m += src[index] * src[index]; + } + + mean /= (float)param->inner_size_; + squ_m /= (float)param->inner_size_; + const double deno = gamma_data[c] / sqrt(squ_m - mean * mean + param->epsilon_); + + index = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 meanv8 = MS_MOV256_F32(mean); + MS_FLOAT32X8 denov8 = MS_MOV256_F32(deno); + for (; index <= param->inner_size_ - C8NUM; index += C8NUM) { + MS_FLOAT32X8 srcv8 = MS_LD256_F32(src + index); + MS_FLOAT32X8 dstv8 = + MS_ADD256_F32(MS_MUL256_F32(MS_SUB256_F32(srcv8, meanv8), denov8), MS_MOV256_F32(*(beta_data + c))); + MS_ST256_F32(dst + index, dstv8); + } +#endif + +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 meanv4 = MS_MOVQ_F32(mean); + MS_FLOAT32X4 denov4 = MS_MOVQ_F32(deno); + for (; index <= param->inner_size_ - C4NUM; index += C4NUM) { + MS_FLOAT32X4 srcv4 = MS_LDQ_F32(src + index); + MS_FLOAT32X4 dstv4 = + MS_ADDQ_F32(MS_MULQ_F32(MS_SUBQ_F32(srcv4, meanv4), denov4), MS_MOVQ_F32(*(beta_data + c))); + MS_STQ_F32(dst + index, dstv4); + } +#endif + for (; index < param->inner_size_; index++) { + dst[index] = (src[index] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} + +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) +void InstanceNormC4HW4ArmSse(const float *src_b, float *dst_b, const float *gamma_data, const float *beta_data, + int32_t *c_src, const InstanceNormParameter *param, int channel, int channel_end, + int hw_plane, MS_FLOAT32X4 hw_planev) { + int c = *c_src; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; + const float *src2 = src_b + (c + C8NUM) * hw_plane, *src3 = src_b + (c + C12NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f), mean3 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m2 = MS_MOVQ_F32(0.0f), squ_m3 = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); + MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2), squarev3 = MS_MULQ_F32(srcv3, srcv3); + MS_ADDQ_F32_VEC(mean, mean1, mean2, mean3, srcv, srcv1, srcv2, srcv3); + MS_ADDQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, squarev, squarev1, squarev2, squarev3); + } + MS_DIVQ_F32_VEC(mean, mean1, mean2, mean3, hw_planev); + MS_DIVQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, hw_planev); + + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno2 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m2, 
MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno3 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_)); + + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); + deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2)); + deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] + MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2); // deno * gamma_data[c] + MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3); // deno * gamma_data[c] + MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); + MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM), betav3 = MS_LDQ_F32(beta_data + c + C12NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); + MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); + MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2), outv3 = MS_SUBQ_F32(srcv3, mean3); + + outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); + outv2 = MS_MULQ_F32(outv2, gammav2), outv3 = MS_MULQ_F32(outv3, gammav3); + MS_ADDQ_F32_VEC(outv, outv1, outv2, outv3, betav, betav1, betav2, betav3); + + MS_STQ_F32(dst + index * channel, outv), MS_STQ_F32(dst + index * channel + C4NUM, outv1); + MS_STQ_F32(dst + index * channel + C8NUM, outv2), MS_STQ_F32(dst + index * channel + C12NUM, outv3); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); + mean = MS_ADDQ_F32(mean, srcv), mean1 = MS_ADDQ_F32(mean1, srcv1); + squ_m = MS_ADDQ_F32(squ_m, squarev), squ_m1 = MS_ADDQ_F32(squ_m1, squarev1); + } + + MS_DIVQ_F32_VEC(mean, mean1, squ_m, squ_m1, hw_planev); + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] + MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); + outv = MS_MULQ_F32(outv, gammav), 
outv1 = MS_MULQ_F32(outv1, gammav1); + outv = MS_ADDQ_F32(outv, betav), outv1 = MS_ADDQ_F32(outv1, betav1); + MS_STQ_F32(dst + index * channel, outv); + MS_STQ_F32(dst + index * channel + C4NUM, outv1); + } + } + for (; c <= channel_end - C4NUM; c += C4NUM) { + const float *src = src_b + c * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), squ_m = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), squarev = MS_MULQ_F32(srcv, srcv); + mean = MS_ADDQ_F32(mean, srcv), squ_m = MS_ADDQ_F32(squ_m, squarev); + } + mean = MS_DIVQ_F32(mean, hw_planev), squ_m = MS_DIVQ_F32(squ_m, hw_planev); + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno), betav = MS_LDQ_F32(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), outv = MS_SUBQ_F32(srcv, mean); + MS_STQ_F32(dst + index * channel, MS_ADDQ_F32(MS_MULQ_F32(outv, gammav), betav)); + } + } + *c_src = c; +} +#endif + +int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + int channel_step = UP_DIV(UP_DIV(channel, C4NUM), param->op_parameter_.thread_num_) * C4NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 hw_planev = MS_MOVQ_F32((float)(hw_plane)); +#endif + for (int b = 0; b < param->batch_; b++) { + const float *src_b = src_data + b * channel * hw_plane; + float *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + InstanceNormC4HW4ArmSse(src_b, dst_b, gamma_data, beta_data, &c, param, channel, channel_end, hw_plane, hw_planev); +#endif + for (; c < channel_end; ++c) { + int c4_down_loop = c / C4NUM * C4NUM; + int c4_mod = c % C4NUM; + int c_res = MSMIN(channel_end - c4_down_loop, C4NUM); + const float *src = src_b + c4_down_loop * hw_plane + c4_mod; + float *dst = dst_b + c; + float mean = 0.0f; + float squ_m = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float tmp = src[index * c_res]; + mean += tmp; + squ_m += tmp * tmp; + } + mean /= (float)hw_plane; + squ_m /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} + +#ifdef ENABLE_AVX +int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + 
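+  // Descriptive note on the NC8HW8 path below: each thread's channel range is
+  // tiled in C8 blocks, one pass accumulates the sum and sum of squares over the
+  // HW plane, and normalization is applied as (x - mean) * (gamma * rsqrt(var + epsilon)) + beta.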
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
+  int channel = param->channel_, hw_plane = param->inner_size_;
+  int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM;
+  int channel_begin = (int)(task_id)*channel_step;
+  int channel_end = MSMIN(channel_begin + channel_step, channel);
+  MS_FLOAT32X8 hw_planev = MS_MOV256_F32((float)(hw_plane));
+  for (int b = 0; b < param->batch_; b++) {
+    const float *src_b = src_data + b * channel * hw_plane;
+    float *dst_b = dst_data + b * channel * hw_plane;
+    int c = channel_begin;
+    for (; c <= channel_end - C16NUM; c += C16NUM) {
+      const float *src = src_b + c * hw_plane;
+      const float *src1 = src_b + (c + C8NUM) * hw_plane;
+      float *dst = dst_b + c;
+      MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), mean1 = MS_MOV256_F32(0.0f);
+      MS_FLOAT32X8 squ_m = MS_MOV256_F32(0.0f), squ_m1 = MS_MOV256_F32(0.0f);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM);
+        MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv), squarev1 = MS_MUL256_F32(srcv1, srcv1);
+        mean = MS_ADD256_F32(mean, srcv);
+        mean1 = MS_ADD256_F32(mean1, srcv1);
+        squ_m = MS_ADD256_F32(squ_m, squarev);
+        squ_m1 = MS_ADD256_F32(squ_m1, squarev1);
+      }
+      mean = MS_DIV256_F32(mean, hw_planev);
+      mean1 = MS_DIV256_F32(mean1, hw_planev);
+      squ_m = MS_DIV256_F32(squ_m, hw_planev);
+      squ_m1 = MS_DIV256_F32(squ_m1, hw_planev);
+      MS_FLOAT32X8 deno =
+        MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), MS_MOV256_F32(param->epsilon_));
+      MS_FLOAT32X8 deno1 =
+        MS_ADD256_F32(MS_SUB256_F32(squ_m1, MS_MUL256_F32(mean1, mean1)), MS_MOV256_F32(param->epsilon_));
+      deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno));
+      deno1 = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno1));
+
+      MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno);            // deno * gamma_data[c]
+      MS_FLOAT32X8 gammav1 = MS_MUL256_F32(MS_LD256_F32(gamma_data + c + C8NUM), deno1);  // deno1 * gamma_data[c]
+      MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c), betav1 = MS_LD256_F32(beta_data + c + C8NUM);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM);
+        MS_FLOAT32X8 outv = MS_SUB256_F32(srcv, mean), outv1 = MS_SUB256_F32(srcv1, mean1);
+        outv = MS_MUL256_F32(outv, gammav);
+        outv1 = MS_MUL256_F32(outv1, gammav1);
+        outv = MS_ADD256_F32(outv, betav);
+        outv1 = MS_ADD256_F32(outv1, betav1);
+        MS_ST256_F32(dst + index * channel, outv);
+        MS_ST256_F32(dst + index * channel + C8NUM, outv1);
+      }
+    }
+    for (; c <= channel_end - C8NUM; c += C8NUM) {
+      const float *src = src_b + c * hw_plane;
+      float *dst = dst_b + c;
+      MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), squ_m = MS_MOV256_F32(0.0f);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM);
+        MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv);
+        mean = MS_ADD256_F32(mean, srcv);
+        squ_m = MS_ADD256_F32(squ_m, squarev);
+      }
+      mean = MS_DIV256_F32(mean, hw_planev);
+      squ_m = MS_DIV256_F32(squ_m, hw_planev);
+      MS_FLOAT32X8 deno =
+        MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), MS_MOV256_F32(param->epsilon_));
+      deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno));
+
+      MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno);  // deno * gamma_data[c]
+      MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c);
+      for (int index
= 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), outv = MS_SUB256_F32(srcv, mean); + outv = MS_MUL256_F32(outv, gammav); + outv = MS_ADD256_F32(outv, betav); + MS_ST256_F32(dst + index * channel, outv); + } + } + for (; c < channel_end; ++c) { + int c8_down_loop = c / C8NUM * C8NUM, c8_mod = c % C8NUM; + int c_res = MSMIN(channel_end - c8_down_loop, C8NUM); + const float *src = src_b + c8_down_loop * hw_plane + c8_mod; + float *dst = dst_b + c; + float mean = 0.0f, squ_m = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float tmp = src[index * c_res]; + mean += tmp; + squ_m += tmp * tmp; + } + mean /= (float)hw_plane; + squ_m /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6d908e59876c1e1d304c3bdb257fe7ac5d53e42a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ +#define MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/instance_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MS_ADDQ_F32_VEC(in1, in2, in3, in4, v1, v2, v3, v4) \ + in1 = MS_ADDQ_F32(in1, v1); \ + in2 = MS_ADDQ_F32(in2, v2); \ + in3 = MS_ADDQ_F32(in3, v3); \ + in4 = MS_ADDQ_F32(in4, v4); + +#define MS_DIVQ_F32_VEC(in1, in2, in3, in4, v) \ + in1 = MS_DIVQ_F32(in1, v); \ + in2 = MS_DIVQ_F32(in2, v); \ + in3 = MS_DIVQ_F32(in3, v); \ + in4 = MS_DIVQ_F32(in4, v); + +int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +#ifdef ENABLE_AVX +int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +#endif +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..04662b546eab733d1376d9e7689e107c4e47478c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/invert_permutation_fp32.h" +#include "nnacl_c/errorcode.h" + +int InvertPermutation(const int32_t *input, int32_t *output, size_t num) { + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + for (size_t i = 0; i < num; i++) { + size_t index = (size_t)input[i]; + if (index >= num) { + return NNACL_ERR; + } + output[index] = i; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d9cf191775c7e0d8de2d0ac267f4482ffdceb371 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_
+#define MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int InvertPermutation(const int32_t *input, int32_t *output, size_t num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..56087123af660ea22a0945191c746c43c18acc9c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/l2_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+
+int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end) {
+  *sum = 0.0f;
+  int i;
+  for (i = begin; i < end; ++i) {
+    *sum += input_ptr[i] * input_ptr[i];
+  }
+  return NNACL_OK;
+}
+
+int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum,
+                     const int begin, const int end) {
+  bool is_relu = param->act_type_ == ActType_Relu;
+  bool is_relu6 = param->act_type_ == ActType_Relu6;
+  int i;
+  if (sqrt_sum == 0) {
+    return NNACL_ERRCODE_DIVISOR_ZERO;
+  }
+  for (i = begin; i < end; i++) {
+    float tmp = input_ptr[i] / sqrt_sum;
+    if (is_relu) {
+      output_ptr[i] = MSMAX(0, tmp);
+    } else if (is_relu6) {
+      output_ptr[i] = MSMIN(6, MSMAX(0, tmp));
+    } else {
+      output_ptr[i] = tmp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin,
+                       const int end) {
+  bool is_relu = param->act_type_ == ActType_Relu;
+  bool is_relu6 = param->act_type_ == ActType_Relu6;
+
+  const int c = param->shape_[param->shape_num_ - 1];
+  int i = 0;
+  for (i = begin; i < end; ++i) {
+    float square_sum = 0.0f;
+    int j = 0;
+    for (j = 0; j < c; ++j) {
+      const float val = input_ptr[i * c + j];
+      square_sum += val * val;
+    }
+    float sqrt_sum = sqrtf(square_sum > param->epsilon_ ?
square_sum : param->epsilon_); + for (j = 0; j < c; ++j) { + float tmp = input_ptr[i * c + j] / sqrt_sum; + if (is_relu) { + output_ptr[i * c + j] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i * c + j] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i * c + j] = tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..2af8506d29eab9f7881e1725a5610f75c650cb33 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ +#define MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ + +#include "nnacl_c/l2_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end); +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end); +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e311fec57a36951f2679740fa722b2e9383a1aa3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "nnacl_c/fp32/layer_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/layer_norm_fp32_simd.h"
+
+int LayerNormMeanAndSquare(const float *src, int num, float *mean, float *variance) {
+  if (num <= 0) {
+    return NNACL_ERR;
+  }
+  int index = 0;
+  float square_mean = 0.f;
+
+  SIMD_RUN_NO_SCALAR(LayerNormMeanAndSquare, index, src, num, mean, &square_mean);
+
+  for (; index < num; index++) {
+    *mean += src[index];
+    square_mean += src[index] * src[index];
+  }
+  *mean /= (float)num;
+  square_mean /= (float)num;
+  *variance = square_mean - (*mean) * (*mean);
+  return NNACL_OK;
+}
+
+void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data, const float *beta_data, int num,
+                           const float mean, const float deno) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(LayerNormGammaAndBeta, index, dst, src, gamma_data, beta_data, num, mean, deno);
+
+  for (; index < num; index++) {
+    dst[index] = (src[index] - mean) * (deno);
+    dst[index] = dst[index] * gamma_data[index] + beta_data[index];
+  }
+}
+
+int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean,
+              float *out_variance, const LayerNormComputeParam *param, int task_id, int thread_num) {
+  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_);
+  int step = UP_DIV(param->norm_outer_size_, thread_num);
+  int thread_end = MSMIN(((int)task_id + 1) * step, param->norm_outer_size_);
+  for (int i = task_id * step; i < thread_end; i++) {
+    const float *src_norm = src_data + i * param->norm_inner_size_;
+    float *dst_norm = dst_data + i * param->norm_inner_size_;
+    float cur_mean = 0.0f;
+    float cur_variance = 0.0f;
+    int ret = LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &cur_mean, &cur_variance);
+    if (ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+    if (out_mean != NULL) {
+      out_mean[i] = cur_mean;
+    }
+    if (out_variance != NULL) {
+      out_variance[i] = cur_variance;
+    }
+    const float deno = 1 / sqrtf(cur_variance + param->epsilon_);
+    if (param->norm_outer_size_ <= param->params_outer_size_) {
+      for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) {
+        const float *src_param = src_norm + x * param->params_inner_size_;
+        float *dst_param = dst_norm + x * param->params_inner_size_;
+        LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, deno);
+      }
+    } else {
+      int x = i / param->params_outer_size_;
+      const float *gamma = gamma_data + x * param->norm_inner_size_;
+      const float *beta = beta_data + x * param->norm_inner_size_;
+      LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno);
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..cfb0aaf257318b8beffa67441d82a53c76ccf1af
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_LAYER_NORM_FP32_H_ +#define NNACL_FP32_LAYER_NORM_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/layer_norm_parameter.h" +#include "nnacl_c/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean, + float *out_variance, const LayerNormComputeParam *param, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_LAYER_NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..e05c56667c8c1b5144909689c21c3e1b2a20e750 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_LAYER_NORM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_LAYER_NORM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int LayerNormMeanAndSquare@SIMD_INSTRUCTION@(int index, const float *src, int num, float *mean, float *square_mean) { + if (num >= 4 * BLOCK_NUM) { + SIMD_F32 sum_val = SIMD_SET0_F32; + SIMD_F32 square_sum_val = SIMD_SET0_F32; + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(src + index); + SIMD_F32 square_value = SIMD_MUL_F32(value, value); + sum_val = SIMD_ADD_F32(sum_val, value); + square_sum_val = SIMD_ADD_F32(square_sum_val, square_value); + } + *mean += SIMD_GET_SUM_F32(sum_val); + *square_mean += SIMD_GET_SUM_F32(square_sum_val); + } + return index; +} + +static inline int LayerNormGammaAndBeta@SIMD_INSTRUCTION@(int index, float *dst, const float *src, const float *gamma_data, + const float *beta_data, int num, const float mean, const float deno) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 deno_val = SIMD_MOV_F32(deno); + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(src + index); + SIMD_F32 out_value = SIMD_SUB_F32(value, mean_val); + out_value = SIMD_MUL_F32(out_value, deno_val); + out_value = SIMD_FMADD_F32(out_value, SIMD_LD_F32(gamma_data + index), SIMD_LD_F32(beta_data + index)); + SIMD_ST_F32(dst + index, out_value); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..25cdba7e594789b5eba4eb7b33f7fb5e4d843a58 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp32/local_response_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+
+int LocalResponseNorm(const float *input_ptr, int out_size, int channel, float *output_ptr,
+                      const LocalResponseNormParameter *param) {
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  int64_t depth_radius = param->depth_radius_;
+  float bias = param->bias_;
+  float alpha = param->alpha_;
+  float beta = param->beta_;
+
+  for (int i = 0; i < out_size; i++) {
+    const float *in_data = input_ptr + i * channel;
+    float *out_data = output_ptr + i * channel;
+    // border_left
+    for (int j = 0; j < MSMIN(depth_radius, channel); j++) {
+      int left = MSMAX(0, j - depth_radius);
+      int right = MSMIN(channel - 1, j + depth_radius);
+      float sum = 0.0f;
+      for (int k = left; k <= right; k++) {
+        const float in_val = in_data[k];
+        sum += in_val * in_val;
+      }
+      out_data[j] = in_data[j] * (float)(powf(sum * alpha + bias, -beta));
+    }
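+    // The center pass below keeps a sliding sum of squares over the
+    // 2 * depth_radius + 1 window, dropping the sample that leaves the window
+    // and adding the one that enters, so it costs O(channel) multiplies per row
+    // instead of O(channel * depth_radius).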
+    // center
+    if (2 * depth_radius + 1 < channel) {
+      float tmp_sum = 0.0f;
+      for (int j = 0; j < depth_radius * 2 + 1; ++j) {
+        tmp_sum += in_data[j] * in_data[j];
+      }
+      out_data[depth_radius] = in_data[depth_radius] * (powf(tmp_sum * alpha + bias, -beta));
+      for (int j = depth_radius + 1; j < channel - depth_radius; ++j) {
+        tmp_sum -= in_data[j - depth_radius - 1] * in_data[j - depth_radius - 1];
+        tmp_sum += in_data[j + depth_radius] * in_data[j + depth_radius];
+        out_data[j] = in_data[j] * (float)(powf(tmp_sum * alpha + bias, -beta));
+      }
+    }
+    // border_right
+    for (int j = MSMAX(0, channel - depth_radius); j < channel; j++) {
+      int left = MSMAX(0, j - depth_radius);
+      int right = MSMIN(channel - 1, j + depth_radius);
+      float sum = 0.0f;
+      for (int k = left; k <= right; k++) {
+        const float in_val = in_data[k];
+        sum += in_val * in_val;
+      }
+      out_data[j] = in_data[j] * (float)(powf(sum * alpha + bias, -beta));
+    }
+  }
+  return 0;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..73448cf10e9c565ed9df948c14470eb9517987e6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_
+#define NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/local_response_norm_parameter.h"
+
+int LocalResponseNorm(const float *input_ptr, int out_size, int channel, float *output_ptr,
+                      const LocalResponseNormParameter *param);
+
+#endif  // NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..4f79799a15304311dafb5a335d34d8bfa568c192
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/log_softmax_fp32.h"
+#include <math.h>
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
+
+void LogSoftmaxLastAxis(const float *src, float *dst, float *exp_data, int batch, int channel) {
+  SoftmaxNorm(src, dst, batch, channel);
+  ExpFp32(dst, exp_data, batch * channel);
+  int cur_batch_offset = 0;
+  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
+    float sum = 0;
+    int j = 0;
+#ifdef ENABLE_NEON
+    float32x4_t sum4 = vdupq_n_f32(0);
+    int count = (channel / C4NUM) * C4NUM;
+    for (; j < count; j += C4NUM) {
+      sum4 = vaddq_f32(sum4, vld1q_f32(exp_data + cur_batch_offset + j));
+    }
+    sum = sum4[0] + sum4[1] + sum4[2] + sum4[3];
+#endif
+    for (; j < channel; j++) {
+      sum += exp_data[cur_batch_offset + j];
+    }
+    for (int k = 0; k < channel; k++) {
+      dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - logf(sum);
+    }
+  }
+}
+
+// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis))
+void LogSoftmax(const float *input_ptr, float *output_ptr, float *sum_data, int32_t *input_shape, int n_dim, int axis) {
+  int inner_size = 1;
+  int outter_size = 1;
+
+  for (int i = 0; i < axis; i++) {
+    outter_size *= input_shape[i];
+  }
+  for (int i = axis + 1; i < n_dim; i++) {
+    inner_size *= input_shape[i];
+  }
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * input_shape[axis] * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      float max_data = input_ptr[inner_offset];
+      sum_data[k + sum_outter_offset] = 0;
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        max_data = max_data > input_ptr[axis_offset] ?
max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = input_ptr[axis_offset] - max_data; + sum_data[k + sum_outter_offset] += expf(output_ptr[axis_offset]); + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] - logf(sum_data[k + sum_outter_offset]); + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..9715999d24d8e6d185466594078c66ae84e007b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_LOG_SOFTMAX_FP32_H_ +#define NNACL_FP32_LOG_SOFTMAX_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/softmax_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void LogSoftmax(const float *input_ptr, float *output_ptr, float *sum_data, int32_t *input_shape, int n_dim, int axis); +void LogSoftmaxLastAxis(const float *src, float *dst, float *exp_data, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_LOG_SOFTMAX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..932c66008bc38592853c89da2f39acb6cdd41f71 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c @@ -0,0 +1,328 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp32/lstm_fp32.h"
+#include <string.h>
+#include <float.h>
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+
+static void PackLstmMatrix(const float *src_batch, float *dst_batch, int col, int deep) {
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major(src_batch, dst_batch, col, deep);
+#elif defined(ENABLE_ARM32)
+  RowMajor2Col4Major(src_batch, dst_batch, col, deep);
+#else
+  RowMajor2Col8Major(src_batch, dst_batch, col, deep);
+#endif
+}
+
+static void PackLstmWeightBatch(float *dst, const float *src, int batch, int deep, int col, int col_align,
+                                const int32_t *order) {
+  for (int i = 0; i < batch; i++) {
+    const float *src_batch = src + i * col * deep;
+    float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align * deep;
+    PackLstmMatrix(src_batch, dst_batch, col, deep);
+  }
+}
+
+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order) {
+  PackLstmWeightBatch(dst, src, batch, deep, col, col_align, order);
+}
+
+void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align,
+                              bool is_bidirectional, int stride, const int32_t *order) {
+  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
+  PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order);
+  src += stride;
+  dst += unidirectional_batch * col_align * deep;
+  if (is_bidirectional) {
+    PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order);
+  }
+}
+
+void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
+                  const int32_t *order) {
+  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
+  for (int i = 0; i < unidirectional_batch; i++) {
+    const float *src_batch = src + i * col;
+    float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align;
+    (void)memcpy(dst_batch, src_batch, col * sizeof(float));
+  }
+  if (is_bidirectional) {
+    const float *backward_src = src + batch * col;
+    float *backward_dst = dst + unidirectional_batch * col_align;
+    for (int i = 0; i < unidirectional_batch; i++) {
+      const float *backward_src_batch = backward_src + i * col;
+      float *backward_dst_batch = backward_dst + ((order == NULL) ? i : order[i]) * col_align;
+      (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float));
+    }
+  }
+}
+
+void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional,
+                            int b_stride, const int32_t *order) {
+  int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
+  for (int i = 0; i < unidirectional_batch; i++) {
+    const float *src_batch = src + i * col;
+    float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align;
+    (void)memcpy(dst_batch, src_batch, col * sizeof(float));
+  }
+  if (is_bidirectional) {
+    const float *backward_src = src + b_stride;
+    float *backward_dst = dst + unidirectional_batch * col_align;
+    for (int i = 0; i < unidirectional_batch; i++) {
+      const float *backward_src_batch = backward_src + i * col;
+      float *backward_dst_batch = backward_dst + ((order == NULL) ?
i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float)); + } + } +} + +void PackLstmInput(const float *src, float *dst, int row, int deep) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src, dst, row, deep); +#elif defined(ENABLE_SSE) + RowMajor2Col4Major(src, dst, row, deep); +#else + RowMajor2Col12Major(src, dst, row, deep); +#endif +} + +void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, + bool is_vec, float *packed_ptr) { + if (is_vec) { +#ifdef ENABLE_AVX + bool need_packed = col % C8NUM; + if (!need_packed) { + packed_ptr = c; + } + MatVecMulAvxFp32(a, b, packed_ptr, bias, ActType_No, deep, col, col_align); + if (need_packed) { + PackNHWCXToNHWCFp32(packed_ptr, c, 1, row, col, C8NUM); + } +#else + MatVecMulFp32(a, b, c, bias, ActType_No, deep, col); +#endif + } else { + MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) { + int index = 0; +#ifdef ENABLE_ARM + for (; index <= element_size - 4; index += 4) { + float32x4_t in_0 = vld1q_f32(input0 + index); + float32x4_t in_1 = vld1q_f32(input1 + index); + float32x4_t out = vld1q_f32(output + index); + out = vmlaq_f32(out, in_1, in_0); + vst1q_f32(output + index, out); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 4; index += C4NUM) { + float32x4_t vin0 = vld1q_f32(input0 + index); + float32x4_t vout = vld1q_f32(output + index); + vout = vmlaq_n_f32(vout, vin0, input1); + vst1q_f32(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + +void UpdateState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate, + float *state_buffer, int batch, int hidden_size, const float zoneout) { + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // zoneout * old_cell_state + (void)memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float)); + ElementOptMul(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); + } + + ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); + + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // (1 - zoneout) * new_cell_state + ElementOptMulAcc(cell_state, 1 - zoneout, state_buffer, batch * hidden_size); + } +} + +void UpdateOutput(float *hidden_state, float *output, const float *cell_state, const float *output_gate, + const float *weight_project, float *buffer[C8NUM], const LstmParameter *lstm_param) { + int batch = lstm_param->batch_; + int hidden_size = lstm_param->hidden_size_; + int output_size = lstm_param->output_size_; + float *state_buffer = buffer[C4NUM]; + float *hidden_buffer = weight_project ? 
buffer[C2NUM] : hidden_state; + float zoneout = lstm_param->zoneout_hidden_; + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float)); + ElementOptMul(state_buffer, &zoneout, state_buffer, batch * output_size, false); + } + + Tanh(cell_state, batch * hidden_size, hidden_buffer); + ElementMul(hidden_buffer, output_gate, hidden_buffer, batch * hidden_size); + + if (weight_project) { + float *left_matrix = hidden_buffer; + if (batch != 1) { + left_matrix = buffer[C6NUM]; + PackLstmInput(hidden_buffer, left_matrix, batch, hidden_size); + } + LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, output_size, + lstm_param->proj_col_align_, batch == 1, buffer[C7NUM]); + } + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * output_size); + } + (void)memcpy(output, hidden_state, batch * output_size * sizeof(float)); +} + +void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, + int col, int col_align, bool is_vec, float *packed_ptr) { + const float *weight_i = weight; + const float *bias_i = bias; + float *gate_i = gate_buffer; + for (int i = 0; i < 4; i++) { + LstmMatMul(gate_i, input, weight_i, bias_i, row, deep, col, col_align, is_vec, packed_ptr); + +#ifdef ENABLE_AVX + weight_i += deep * col_align; +#else + weight_i += deep * (is_vec ? col : col_align); +#endif + bias_i += col_align; + gate_i += row * col; + } +} + +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, const float *weight_project, float *hidden_state, + float *cell_state, float *buffer[C8NUM], const LstmParameter *lstm_param) { + float *packed_state = buffer[1]; + float *state_gate = buffer[C2NUM]; + float *cell_buffer = buffer[C3NUM]; + float *hidden_buffer = buffer[C4NUM]; + float *packed_output = buffer[C5NUM]; + bool is_vec = lstm_param->batch_ == 1; + // state * weight + if (is_vec) { + UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); + } else { + // pack state for matmul + PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); + UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); + } + ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + + // update input_gate + Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); + + // update forget_gate + Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); + + // update cell_gate + Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); + // update 
cell state + UpdateState(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->zoneout_cell_); + + // update output_gate + Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); + // update output + UpdateOutput(hidden_state, output, cell_state, output_gate, weight_project, buffer, lstm_param); + + if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { + (void)memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); + } + + if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { + (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float)); + } +} + +void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h, + const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, + float *buffer[C8NUM], const LstmParameter *lstm_param, bool is_backward) { + float *gate = buffer[0]; + for (int i = 0; i < 4; i++) { + const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, + lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, + OutType_Nhwc); + } + + float *input_gate = gate; + float *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? 
lstm_param->seq_len_ - t - 1 : t; + float *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, NULL, + hidden_state, cell_state, buffer, lstm_param); + } +} + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[C9NUM], + const LstmParameter *lstm_param) { + // forward + float *packed_input = buffer[0]; + buffer += 1; + PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_); + LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, buffer, + lstm_param, false); + + // backward + if (lstm_param->bidirectional_) { + const float *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; + const float *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; + float *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; + float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; + + LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..f439bb6b46a15bc76ca1f9b38ca50695d34543c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_LSTM_H_ +#define MINDSPORE_NNACL_FP32_LSTM_H_ + +#include "nnacl_c/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order); + +void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, + bool is_bidirectional, int stride, const int32_t *order); + +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + int b_stride, const int32_t *order); + +void PackLstmInput(const float *src, float *dst, int row, int deep); + +void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, + bool is_vec, float *packed_ptr); + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size); + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); + +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, const float *weight_project, float *hidden_state, + float *cell_state, float *buffer[C8NUM], const LstmParameter *lstm_param); + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[C9NUM], + const LstmParameter *lstm_param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_LSTM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9fa54e672e16976189b4581a959bd8a673031220 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void GemmRowxColKernelFp32(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + __m512 dst_data[27]; + const float *src_sw[20]; + __m512 weight_data[6]; + for (int i = 0; i < C6NUM; ++i) { + weight_data[i] = _mm512_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (inc_flag & 0x01) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(dst + i * dst_stride + j * C16NUM); + } + } else if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(bias + j * C16NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_set1_ps(0.0f); + } + } + src_sw[i] = src + i * src_stride; + } + const float *weight_kernel = weight; + for (int k = 0; k < depth; ++k) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm512_loadu_ps(weight_kernel + j * C16NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = + _mm512_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j]); + } + } + weight_kernel += C16NUM * col_block; + } // k loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm512_min_ps(dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm512_max_ps(dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_storeu_ps(dst + i * dst_stride + j * C16NUM, dst_data[i * col_block + j]); + } + } +} + +void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, const int depth, + const int cur_col, const int col_align, const int row) { + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + GemmAvx512Kernel kernel[C4NUM][C13NUM]; + int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM}; + +#ifdef ENABLE_DEBUG + for (int i = 0; i < C4NUM; i++) { + for (int j = 0; j < C13NUM; j++) { + kernel[i][j] = GemmRowxColKernelFp32; + } + } +#else + kernel[0][1] = nnacl_gemm_avx512_1x16_kernel_nhwc_fp32; + kernel[0][2] = nnacl_gemm_avx512_2x16_kernel_nhwc_fp32; + kernel[0][3] = nnacl_gemm_avx512_3x16_kernel_nhwc_fp32; + kernel[0][4] = nnacl_gemm_avx512_4x16_kernel_nhwc_fp32; + kernel[0][5] = nnacl_gemm_avx512_5x16_kernel_nhwc_fp32; + kernel[0][6] = nnacl_gemm_avx512_6x16_kernel_nhwc_fp32; + kernel[0][7] = nnacl_gemm_avx512_7x16_kernel_nhwc_fp32; + kernel[0][8] = nnacl_gemm_avx512_8x16_kernel_nhwc_fp32; + kernel[0][9] = nnacl_gemm_avx512_9x16_kernel_nhwc_fp32; + kernel[0][10] = nnacl_gemm_avx512_10x16_kernel_nhwc_fp32; + kernel[0][11] = nnacl_gemm_avx512_11x16_kernel_nhwc_fp32; + kernel[0][12] = nnacl_gemm_avx512_12x16_kernel_nhwc_fp32; + + kernel[1][1] = nnacl_gemm_avx512_1x32_kernel_nhwc_fp32; + kernel[1][2] = nnacl_gemm_avx512_2x32_kernel_nhwc_fp32; + kernel[1][3] = nnacl_gemm_avx512_3x32_kernel_nhwc_fp32; + kernel[1][4] = nnacl_gemm_avx512_4x32_kernel_nhwc_fp32; + 
kernel[1][5] = nnacl_gemm_avx512_5x32_kernel_nhwc_fp32; + kernel[1][6] = nnacl_gemm_avx512_6x32_kernel_nhwc_fp32; + kernel[1][7] = nnacl_gemm_avx512_7x32_kernel_nhwc_fp32; + kernel[1][8] = nnacl_gemm_avx512_8x32_kernel_nhwc_fp32; + kernel[1][9] = nnacl_gemm_avx512_9x32_kernel_nhwc_fp32; + kernel[1][10] = nnacl_gemm_avx512_10x32_kernel_nhwc_fp32; + kernel[1][11] = nnacl_gemm_avx512_11x32_kernel_nhwc_fp32; + kernel[1][12] = nnacl_gemm_avx512_12x32_kernel_nhwc_fp32; + + kernel[2][1] = nnacl_gemm_avx512_1x48_kernel_nhwc_fp32; + kernel[2][2] = nnacl_gemm_avx512_2x48_kernel_nhwc_fp32; + kernel[2][3] = nnacl_gemm_avx512_3x48_kernel_nhwc_fp32; + kernel[2][4] = nnacl_gemm_avx512_4x48_kernel_nhwc_fp32; + kernel[2][5] = nnacl_gemm_avx512_5x48_kernel_nhwc_fp32; + kernel[2][6] = nnacl_gemm_avx512_6x48_kernel_nhwc_fp32; + kernel[2][7] = nnacl_gemm_avx512_7x48_kernel_nhwc_fp32; + kernel[2][8] = nnacl_gemm_avx512_8x48_kernel_nhwc_fp32; + + kernel[3][1] = nnacl_gemm_avx512_1x64_kernel_nhwc_fp32; + kernel[3][2] = nnacl_gemm_avx512_2x64_kernel_nhwc_fp32; + kernel[3][3] = nnacl_gemm_avx512_3x64_kernel_nhwc_fp32; + kernel[3][4] = nnacl_gemm_avx512_4x64_kernel_nhwc_fp32; + kernel[3][5] = nnacl_gemm_avx512_5x64_kernel_nhwc_fp32; + kernel[3][6] = nnacl_gemm_avx512_6x64_kernel_nhwc_fp32; +#endif + + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + // one time process 64 out_channel + int col_block = C64NUM; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = MSMIN(col_block, cur_col - col_index); + int row_block = max_shape[(col_block >> C4NUM) - 1]; + for (int m = 0; m < row; m += row_block) { + row_block = MSMIN(row_block, row - m); + kernel[(col_block >> C4NUM) - 1][row_block](c + col_index + m * col_align, a + m * depth + k, + b + col_index * depth + k * col_block, bias_data, act_flag, + row_block, col_block >> C4NUM, k_block, depth, col_align, inc_flag); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_align) { + // one time process 64 out_channel + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } +#ifdef ENABLE_DEBUG + GemmAvx512Kernel kernel[C4NUM] = {GemmRowxColKernelFp32, GemmRowxColKernelFp32, GemmRowxColKernelFp32, + GemmRowxColKernelFp32}; +#else + GemmAvx512Kernel kernel[C4NUM] = {nnacl_gemm_avx512_1x16_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_kernel_nhwc_fp32, + nnacl_gemm_avx512_1x48_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_kernel_nhwc_fp32}; +#endif + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + int col_block = C64NUM; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = MSMIN(col_block, cur_col - col_index); + kernel[(col_block >> C4NUM) - 1](c + col_index, a + k, b + col_index * depth + k * col_block, bias_data, act_flag, + 1, col_block >> C4NUM, k_block, depth, col_align, inc_flag); + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +// act_type must be 
0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
+                                    int k, int act_type) {
+  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
+  // block 8
+  MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
+  MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
+  for (; m_index <= m - C8NUM; m_index += C8NUM) {
+    int k_index = 0;
+    MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
+    MS_SET_ZERO512X8_F32(dst16_)
+    for (; k_index <= k - C16NUM; k_index += C16NUM) {
+      __m512 weight = _mm512_loadu_ps(b + k_index);
+      MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
+      MS_FMADD512X8_F32(src, weight, dst16_)
+    }
+    // horizontally reduce each 512-bit accumulator into one lane of the 8-row result
+    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
+    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
+    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
+    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
+    MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
+    MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
+    MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
+    MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
+    for (; k_index < k; k_index++) {
+      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
+      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
+      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
+      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
+      MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
+      MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
+      MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
+      MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
+    }
+
+    if (act_type != 0) {
+      dst = MS_MAX256_F32(dst, down_threshold256);
+      if (act_type == 3) {  // 3: relu6
+        dst = MS_MIN256_F32(dst, up_threshold256);
+      }
+    }
+
+    MS_ST256_F32(c + m_index, dst);
+  }
+  return m_index;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..43c4d79c9c7f7b4769e29b9bfbee5f40f80305e0
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ +#include "nnacl_c/op_base.h" +typedef void (*GemmAvx512Kernel)(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +#ifdef __cplusplus +extern "C" { +#endif +void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int cur_col, int col_align); + +void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align, int row); + +int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m, + int k, int act_type); + +// 64 block +void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 48 block +void nnacl_gemm_avx512_8x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x48_kernel_nhwc_fp32(float 
*dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 32 block +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_11x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_10x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_9x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, 
const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 16 block +void nnacl_gemm_avx512_12x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_11x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_10x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_9x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_8x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t 
src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4d761d831a2a1107ae9eb450bcfc14f85d1fd266 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c @@ -0,0 +1,236 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void GemmRowxColMaskKernelFp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + __m512 dst_data[27]; + const float *src_sw[20]; + __m512 weight_data[6]; + __mmask16 mask16 = (__mmask16)(*mask); + for (int i = 0; i < C6NUM; ++i) { + weight_data[i] = _mm512_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (inc_flag & 0x01) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(dst + i * dst_stride + j * C16NUM); + } + } else if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(bias + j * C16NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_set1_ps(0.0f); + } + } + src_sw[i] = src + i * src_stride; + } + const float *weight_kernel = weight; + for (int k = 0; k < depth; ++k) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm512_loadu_ps(weight_kernel + j * C16NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (j == col_block - 1) { + dst_data[i * col_block + j] = + _mm512_mask3_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j], mask16); + } else { + dst_data[i * col_block + j] = + _mm512_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j]); + } + } + } + weight_kernel += C16NUM * col_block; + } // k loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (j == col_block - 1) { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = + _mm512_maskz_min_ps(mask16, dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = + _mm512_maskz_max_ps(mask16, dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_mask_storeu_ps(dst + i * dst_stride + j * C16NUM, mask16, dst_data[i * col_block + j]); + } else { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm512_min_ps(dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm512_max_ps(dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_storeu_ps(dst + i * dst_stride + j * C16NUM, dst_data[i * col_block + j]); + } + } + } +} + +void MatMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_, const int row) { + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + GemmAvx512MaskKernel kernel[C4NUM][C13NUM]; + int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM}; + +#ifdef ENABLE_DEBUG + for (int i = 0; i < C4NUM; i++) { + for (int j = 0; j < C13NUM; j++) { + kernel[i][j] = GemmRowxColMaskKernelFp32; + } + } +#else + kernel[0][1] = nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32; + kernel[0][2] = nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32; + kernel[0][3] = nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32; + kernel[0][4] 
= nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32; + kernel[0][5] = nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32; + kernel[0][6] = nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32; + kernel[0][7] = nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32; + kernel[0][8] = nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32; + kernel[0][9] = nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32; + kernel[0][10] = nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32; + kernel[0][11] = nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32; + kernel[0][12] = nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32; + + kernel[1][1] = nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32; + kernel[1][2] = nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32; + kernel[1][3] = nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32; + kernel[1][4] = nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32; + kernel[1][5] = nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32; + kernel[1][6] = nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32; + kernel[1][7] = nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32; + kernel[1][8] = nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32; + kernel[1][9] = nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32; + kernel[1][10] = nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32; + kernel[1][11] = nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32; + kernel[1][12] = nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32; + + kernel[2][1] = nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32; + kernel[2][2] = nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32; + kernel[2][3] = nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32; + kernel[2][4] = nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32; + kernel[2][5] = nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32; + kernel[2][6] = nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32; + kernel[2][7] = nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32; + kernel[2][8] = nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32; + + kernel[3][1] = nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32; + kernel[3][2] = nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32; + kernel[3][3] = nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32; + kernel[3][4] = nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32; + kernel[3][5] = nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32; + kernel[3][6] = nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32; +#endif + + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + // one time process 64 out_channel + int col_block = C64NUM; + u_int16_t avx512_mask = 0xFFFF; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + int less_tmp = cur_col - col_index; + if (less_tmp < col_block) { + col_block = UP_ROUND(less_tmp, C16NUM); + avx512_mask = (0xFFFF >> (col_block - less_tmp)); + } + int col_block_num = col_block >> C4NUM; + int row_block = max_shape[col_block_num - 1]; + for (int m = 0; m < row; m += row_block) { + row_block = MSMIN(row_block, row - m); + kernel[col_block_num - 1][row_block](c + col_index + m * col_, a + m * depth + k, + b + col_index * depth + k * col_block, bias_data, act_flag, row_block, + col_block_num, k_block, depth, col_, inc_flag, &avx512_mask); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +void MatVecMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_) { + // one time process 64 out_channel + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + 
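+  // act_flag encodes the activation in two bits: bit 0 = relu6 (set above), bit 1 = relu
+  // (set below, also for relu6), so relu6 clamps at 0.0f first and then at 6.0f.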
if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } +#ifdef ENABLE_DEBUG + GemmAvx512MaskKernel kernel[C4NUM] = {GemmRowxColMaskKernelFp32, GemmRowxColMaskKernelFp32, GemmRowxColMaskKernelFp32, + GemmRowxColMaskKernelFp32}; +#else + GemmAvx512MaskKernel kernel[C4NUM] = { + nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32, + nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32}; +#endif + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + int col_block = C64NUM; + u_int16_t avx512_mask = 0xFFFF; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + int less_tmp = cur_col - col_index; + if (less_tmp < col_block) { + col_block = UP_ROUND(less_tmp, C16NUM); + avx512_mask = (0xFFFF >> (col_block - less_tmp)); + } + int col_block_num = col_block >> C4NUM; + + kernel[col_block_num - 1](c + col_index, a + k, b + col_index * depth + k * col_block, bias_data, act_flag, 1, + col_block_num, k_block, depth, col_, inc_flag, &avx512_mask); + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..aa0fdfba20333b37b51f8c7c596d8e36632331f4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h @@ -0,0 +1,209 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_
+#include <stdint.h>
+#include <sys/types.h>  // u_int16_t
+#include "nnacl_c/op_base.h"
+typedef void (*GemmAvx512MaskKernel)(float *dst, const float *src, const float *weight, const float *bias,
+                                     const size_t act_flag, const size_t row_block, const size_t col_block,
+                                     const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                     const size_t inc_flag, const u_int16_t *mask);
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void GemmRowxColMaskKernelFp32(float *dst, const float *src, const float *weight, const float *bias,
+                               const size_t act_flag, const size_t row_block, const size_t col_block,
+                               const size_t depth, const size_t src_stride, const size_t dst_stride,
+                               const size_t inc_flag, const u_int16_t *mask);
+
+void MatVecMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type,
+                             const int depth, const int cur_col, const int col_);
+
+void MatMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type,
+                          const int depth, const int cur_col, const int col_, const int row);
+
+// 64 block
+void nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+
+// 48 block
+void nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block,
const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 32 block +void nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t 
src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 16 block +void nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, 
const u_int16_t *mask);
+void nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+void nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..7f49814ee702eaf24f33d083de113417fabd5771
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c
@@ -0,0 +1,954 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/matmul_avx_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" + +void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align) { + // one time process 32 out_channel + int col_block = C32NUM; + int act_flag = C0NUM; + if (act_type == ActType_Relu6) { + act_flag += C1NUM; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel}; + const float *bias_data = bias; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block; + kernel[(col_block >> C3NUM) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, + col_block >> C3NUM, col_align, depth); + if (bias_data != NULL) { + bias_data += col_block; + } + } +} + +void MatMulAvxFp32(const float *a, const float *b, float *c, const float *bias, const int act_type, const int depth, + const int cur_col, const int col_align, const int row) { + // one time process 32 out_channel + int col_block = C32NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + int row_tile[4] = {C8NUM, C6NUM, C4NUM, C3NUM}; + MatVecMulKernel kernel[4][2] = {{MatVecMul1x8Kernel, MatMul8x8Kernel}, + {MatVecMul1x16Kernel, MatMul6x16Kernel}, + {MatVecMul1x24Kernel, MatMul4x24Kernel}, + {MatVecMul1x32Kernel, MatMul3x32Kernel}}; + const float *bias_data = bias; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = cur_col - col_index < col_block ? 
cur_col - col_index : col_block; + int row_block = row_tile[(col_block >> C3NUM) - 1]; + for (int r = 0; r < row; r += row_block) { + if (row_block > row - r) { + row_block = 1; + } + kernel[(col_block >> C3NUM) - 1][row_block / row_tile[(col_block >> C3NUM) - 1]]( + c + col_index + r * col_align, a + r * depth, b + col_index * depth, bias_data, act_flag, row_block, + col_block >> C3NUM, col_align, depth); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } +} + +void MatMul3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + col_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vbroadcastss (%0), %%ymm12\n" // src + "vbroadcastss (%0, %7), %%ymm13\n" + "vbroadcastss (%0, %7, 2), %%ymm14\n" + "vmovups (%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "vmovups 0x20(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm5\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n" + + "vmovups 0x40(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm10\n" + + "vmovups 0x60(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + "addq $128, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, 
%%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, 0x60(%5)\n" + "vmovups %%ymm4, (%5, %6)\n" // dst_1 + "vmovups %%ymm5, 0x20(%5, %6)\n" + "vmovups %%ymm6, 0x40(%5, %6)\n" + "vmovups %%ymm7, 0x60(%5, %6)\n" + "vmovups %%ymm8, (%5, %6, 2)\n" // dst_2 + "vmovups %%ymm9, 0x20(%5, %6, 2)\n" + "vmovups %%ymm10, 0x40(%5, %6, 2)\n" + "vmovups %%ymm11, 0x60(%5, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), + "r"(deep * sizeof(float)) // 7 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // deep_c8 + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 512(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 544(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 576(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 608(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 640(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 672(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 704(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 736(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 768(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 800(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 832(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 864(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 896(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 928(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 960(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 992(%1), %%ymm4, %%ymm3\n" + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" // deep_remainder + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + "addq $128, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" 
+ "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, 0x60(%5)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_3 = dst + C3NUM * col_algin; + col_algin *= sizeof(float); + size_t src_3_step = C3NUM * deep * sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vmovups (%1), %%ymm12\n" // weight + "vmovups 0x20(%1), %%ymm13\n" + "vmovups 0x40(%1), %%ymm14\n" + + "vbroadcastss (%0), %%ymm15\n" // src + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm2\n" + + "vbroadcastss (%0, %9), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm5\n" + + "vbroadcastss (%0, %9, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "vbroadcastss (%0, %7), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + "addq $96, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, 
%%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, (%5, %6)\n" + "vmovups %%ymm4, 0x20(%5, %6)\n" // dst_1 + "vmovups %%ymm5, 0x40(%5, %6)\n" + "vmovups %%ymm6, (%5, %6, 2)\n" + "vmovups %%ymm7, 0x20(%5, %6, 2)\n" + "vmovups %%ymm8, 0x40(%5, %6, 2)\n" // dst_2 + "vmovups %%ymm9, (%8)\n" + "vmovups %%ymm10, 0x20(%8)\n" + "vmovups %%ymm11, 0x40(%8)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(src_3_step), "r"(dst_3), + "r"(deep * sizeof(float)) // 9 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + + "1:\n" // deep + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 512(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 544(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 576(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 608(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 640(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 672(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 704(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 736(%1), %%ymm4, %%ymm2\n" + "addq $768, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" // deep_remainder + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "addq $96, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, 
%%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_3 = dst + 3 * col_algin; + float *dst_5 = dst + 5 * col_algin; + col_algin *= sizeof(float); + size_t src_3_step = 3 * deep * sizeof(float); + size_t src_5_step = 5 * deep * sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups 0x20(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups 0x20(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vmovups (%1), %%ymm12\n" // weight + "vmovups 0x20(%1), %%ymm13\n" + + "vbroadcastss (%0), %%ymm14\n" // src_0 + "vbroadcastss (%0, %11), %%ymm15\n" // src_1 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm3\n" + + "vbroadcastss (%0, %11, 2), %%ymm14\n" // src_2 + "vbroadcastss (%0, %8), %%ymm15\n" // src_3 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm5\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + + "vbroadcastss (%0, %11, 4), %%ymm14\n" // src_4 + "vbroadcastss (%0, %9), %%ymm15\n" // src_5 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm8\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm11\n" + + "addq $64, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + 
"vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, (%5, %6)\n" // dst_1 + "vmovups %%ymm3, 0x20(%5, %6)\n" + "vmovups %%ymm4, (%5, %6, 2)\n" // dst_2 + "vmovups %%ymm5, 0x20(%5, %6, 2)\n" + "vmovups %%ymm6, (%7)\n" // dst_3 + "vmovups %%ymm7, 0x20(%7)\n" + "vmovups %%ymm8, (%5, %6, 4)\n" // dst_4 + "vmovups %%ymm9, 0x20(%5, %6, 4)\n" + "vmovups %%ymm10, (%10)\n" // dst_5 + "vmovups %%ymm11, 0x20(%10)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(dst_3), "r"(src_3_step), + "r"(src_5_step), "r"(dst_5), "r"(deep * sizeof(float)) // 11 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" // deep_c8 + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm1\n" + "addq $512, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "addq $64, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t 
col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" // deep_c8 + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 32(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm0\n" + "addq $256, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "addq $32, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_5 = dst + C5NUM * col_algin; + col_algin *= sizeof(float); + size_t dst_3_step = C3NUM * col_algin; + size_t src_3_step = C3NUM * deep * sizeof(float); + const float *src_5 = C5NUM * deep + src; + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups (%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups (%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups (%2), %%ymm7\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + + "1:\n" // deep + "vmovups (%1), %%ymm15\n" // weight + + "vbroadcastss (%0), %%ymm8\n" // src_0 + "vbroadcastss (%0, %11), %%ymm9\n" // src_1 + "vbroadcastss (%0, %11, 2), %%ymm10\n" // src_2 + "vbroadcastss (%0, %8), %%ymm11\n" // src_3 + "vfmadd231ps %%ymm8, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm3\n" + + "vbroadcastss (%0, %11, 4), %%ymm8\n" // src_4 + "vbroadcastss (%9), %%ymm9\n" // src_5 + "vbroadcastss (%9, %11, 1), %%ymm10\n" // src_6 + "vbroadcastss (%9, %11, 2), %%ymm11\n" // src_7 + "vfmadd231ps %%ymm8, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm5\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm7\n" + + "addq $32, %1\n" + "addq $4, %0\n" + "addq $4, %9\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps 
%%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, (%5, %6)\n" + "vmovups %%ymm2, (%5, %6, 2)\n" + "vmovups %%ymm3, (%5, %7)\n" + "vmovups %%ymm4, (%5, %6, 4)\n" + "vmovups %%ymm5, (%10)\n" + "vmovups %%ymm6, (%10, %6)\n" + "vmovups %%ymm7, (%10, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(dst_3_step), // 7 + "r"(src_3_step), "r"(src_5), "r"(dst_5), "r"(deep * sizeof(float)) // 11 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +#ifdef ENABLE_DEBUG +void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + __m256 dst_data[12]; + const float *src_sw[12]; + __m256 weight_data[4]; + for (int i = 0; i < C4NUM; ++i) { + weight_data[i] = _mm256_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm256_set1_ps(0.0f); + } + } + src_sw[i] = src + i * deep; + } + const float *weight_kernel = weight; + for (int ic = 0; ic < deep; ++ic) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = + _mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]); + } + } + weight_kernel += C8NUM * col_block; + } // ic loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f)); + } + _mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]); + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..88d0242dd354bf92df2eb138dc83de6d001ebbfe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_MATMUL_AVX_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_AVX_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (*DeconvAvxKernel)(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride); +void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, int kernel_plane); +void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth, + size_t row, size_t col, size_t stride, size_t write_mode); +typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align); +void MatMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align, int row); +void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +#ifdef ENABLE_DEBUG +void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride); + +void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +#endif + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c new file mode 100644 index 
0000000000000000000000000000000000000000..f88f44c07c4ac9a9586877441bd809d056107bf5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c @@ -0,0 +1,822 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/matmul_fp32_simd.h" + +#ifndef ENABLE_ARM +void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) { + for (int ci = 0; ci < col; ci++) { + float value = 0; + for (int di = 0; di < depth; di++) { + value += a[di] * b[ci * depth + di]; + } + if (bias != NULL) value += bias[ci]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value); + c[ci] = value; + } +} +#endif + +void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int col) { + int col8 = col / C8NUM * C8NUM; + int ci = 0; + for (; ci < col8; ci += C8NUM, c += C8NUM) { +#ifdef ENABLE_NEON + float32x4_t value0 = vdupq_n_f32(0.0f); + float32x4_t value1 = vdupq_n_f32(0.0f); + for (int di = 0; di < depth; ++di, b += C8NUM) { + value0 += vdupq_n_f32(a[di]) * vld1q_f32(b); + value1 += vdupq_n_f32(a[di]) * vld1q_f32(b + C4NUM); + } + if (bias != NULL) { + value0 += vld1q_f32(bias + ci); + value1 += vld1q_f32(bias + ci + C4NUM); + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + value0 = vmaxq_f32(value0, vdupq_n_f32(0.0f)); + value1 = vmaxq_f32(value1, vdupq_n_f32(0.0f)); + } + if (act_type == ActType_Relu6) { + value0 = vminq_f32(value0, vdupq_n_f32(6.0f)); + value1 = vminq_f32(value1, vdupq_n_f32(6.0f)); + } + vst1q_f32(c, value0); + vst1q_f32(c + 4, value1); +#else + float value[C8NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C8NUM) { + for (int j = 0; j < C8NUM; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < C8NUM; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, C8NUM * sizeof(float)); +#endif + } + int res = col - col8; + float value[C8NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C8NUM) { + for (int j = 0; j < res; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < res; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, res * sizeof(float)); +} + +#ifdef ENABLE_ARM32 +void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int col) { + int col4 = col / C4NUM * C4NUM; + int ci = 0; + for (; ci < col4; ci += C4NUM, c += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t value = vdupq_n_f32(0.0f); + for (int di = 0; di < depth; ++di, b += C4NUM) { + value += vdupq_n_f32(a[di]) * vld1q_f32(b); + } + if (bias != NULL) { + value += 
vld1q_f32(&(bias[ci])); + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + value = vmaxq_f32(value, vdupq_n_f32(0.0f)); + } + if (act_type == ActType_Relu6) { + value = vminq_f32(value, vdupq_n_f32(6.0f)); + } + vst1q_f32(c, value); +#else + float value[C4NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C4NUM) { + for (int j = 0; j < C4NUM; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < C4NUM; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, C4NUM * sizeof(float)); +#endif + } + int res = col - col4; + float value[C4NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C4NUM) { + for (int j = 0; j < res; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < res; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, res * sizeof(float)); +} +#endif + +#ifdef ENABLE_ARM64 +// 4x8 +void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col, + int align_col) { + int ci = 0; + for (; ci < align_col - C8NUM + 1; ci += C8NUM) { + float32x4_t acc_0; + float32x4_t acc_1; + if (bias != NULL) { + acc_0 = vld1q_f32(bias + ci); + acc_1 = vld1q_f32(bias + ci + C4NUM); + } else { + acc_0 = vdupq_n_f32(0.0f); + acc_1 = vdupq_n_f32(0.0f); + } + const float *bv_base = b + ci * depth; + int di = 0; + for (; di < depth - C4NUM + 1; di += C4NUM) { + float32x4_t av = vld1q_f32(a + di); + float32x4_t bv_00 = vld1q_f32(bv_base); + float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_01 = vld1q_f32(bv_base); + float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_02 = vld1q_f32(bv_base); + float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_03 = vld1q_f32(bv_base); + float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]); + acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]); + acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]); + acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]); + acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]); + acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]); + acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]); + acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]); + } + if (di < depth) { + for (; di < depth; ++di) { + float ai = a[di]; + float32x4_t bv0 = vld1q_f32(bv_base); + float32x4_t bv1 = vld1q_f32(bv_base + C4NUM); + acc_0 = vmlaq_n_f32(acc_0, bv0, ai); + acc_1 = vmlaq_n_f32(acc_1, bv1, ai); + bv_base += C8NUM; + } + } // only save actual col num data + if (ci + C4NUM - 1 >= col) { + int c_remain = col - ci; + for (int i = 0; i < c_remain; ++i) { + if (act_type == ActType_Relu) { + c[i] = MSMAX(acc_0[i], 0.0f); + } else if (act_type == ActType_Relu6) { + c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f); + } else { + c[i] = acc_0[i]; + } + } + return; + } + if (act_type == ActType_Relu) { + acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f)); + } else if (act_type == ActType_Relu6) { + acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f)); + } + vst1q_f32(c, acc_0); + if (ci + C8NUM - 1 >= col) { + int c_remain = col - ci - C4NUM; + for (int i = 0; i < c_remain; ++i) { + if (act_type == ActType_Relu) { + c[C4NUM + i] = MSMAX(acc_1[i], 0.0f); + } else if (act_type == ActType_Relu6) { + c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f); + } else { + c[C4NUM + i] = acc_1[i]; + } + } + return; + } + if (act_type == ActType_Relu) { + 
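+      // second half of the 8-wide tile: apply the activation to acc_1 (columns ci + 4 .. ci + 7) before storing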
acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f)); + } else if (act_type == ActType_Relu6) { + acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f)); + } + vst1q_f32(c + C4NUM, acc_1); + c += C8NUM; + } +} +#endif + +void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride, int out_type) { + if (out_type == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (out_type == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_12 = UP_ROUND(row, C12NUM); + for (int r = 0; r < row_12; r++) { + for (int c = 0; c < col_8; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (out_type == OutType_TileC8) { + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C12NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, + int col, size_t stride, int out_type) { +#ifdef ENABLE_ARM64 + if (out_type == OutType_C8) { + MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else if (out_type == OutType_Nhwc && deep > C512NUM) { + BigMatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride); + } else { + MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + } +#elif ENABLE_ARM32 + if (out_type == OutType_C8) { + MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else if (out_type == OutType_Nhwc) { + MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1); + } else { + MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + } +#elif ENABLE_AVX + MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type)); +#elif ENABLE_SSE + MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); +#else + MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); +#endif +} + +#define ActCompute(bit_num, down_threshold, up_threshold) \ + if (act_type != 0) { \ + dst = MS_MAX##bit_num##_F32(dst, down_threshold); \ + if (act_type == 3) { \ + dst = MS_MIN##bit_num##_F32(dst, 
up_threshold); \
+    } \
+  }
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
+
+  for (; index < row; ++index) {
+    float dst = a[index] * b[0] + bias[0];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
+  }
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                            int act_type) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(Row1Deep1GemmIsNotPack, index, a, b, c, bias, col, act_type);
+  for (; index < col; ++index) {
+    float dst = a[0] * b[index] + bias[index];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
+  }
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                                  int act_type) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(Row1Deep1NoBiasGemmIsNotPack, index, a, b, c, bias, col, act_type);
+  for (; index < col; ++index) {
+    float dst = a[0] * b[index];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
+  }
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
+  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
+  int m_index = 0;
+
+  SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);
+
+#ifdef ENABLE_AVX
+  // block 4
+  MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
+  MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
+  for (; m_index <= m - C4NUM; m_index += C4NUM) {
+    int k_index = 0;
+    MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]);
+    MS_SET_ZERO256X4_F32(dst_)
+    for (; k_index <= k - C8NUM; k_index += C8NUM) {
+      MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index);
+      MS_LOAD256X4_F32(src, a + m_index * k + k_index, k);
+      MS_FMADD256X4_F32(src, weight, dst_);
+    }
+    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1);
+    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2);
+    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3);
+    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4);
+    for (; k_index < k; ++k_index) {
+      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
+      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
+      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
+      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
+    }
+    ActCompute(128, down_threshold128, up_threshold128);
+    MS_ST128_F32(c + m_index, dst);
+  }
+#endif
+
+  // block 1
+  for (; m_index < m; m_index++) {
+    float dst = bias[0];
+    int k_index = 0;
+
+    SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
+    SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
+
+    for (; k_index < k; k_index++) {
+      dst += b[k_index] * a[m_index * k + k_index];
+    }
+    ActCompute(32, 0, C6NUM);
+    c[m_index] = dst;
+  }
+}
+
+void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth,
+                         int64_t cur_col, int64_t col) {
+  int inc_flag = 0;
+  int64_t k = 0;
+  for (; k <= depth - C1500NUM; k += C1500NUM) {
+    inc_flag = (k == 0) + (k + C1500NUM == depth ? 
C2NUM : 0); + int64_t oc_index = 0; + SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, C1500NUM, cur_col, col, inc_flag); + for (; oc_index < cur_col; ++oc_index) { + float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]); + for (int64_t k_index = 0; k_index < k; ++k_index) { + dst += a[k_index] * b[oc_index + k_index * col]; + } + if ((inc_flag & 0x2) != 0) { + ActCompute(32, 0, C6NUM); + } + c[oc_index] = dst; + } + a += C1500NUM; + b += C1500NUM * col; + } + if (k == depth) { + return; + } + inc_flag = (k == 0) + C2NUM; + int64_t oc_index = 0; + SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, depth - k, cur_col, col, inc_flag); + for (; oc_index < cur_col; ++oc_index) { + float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]); + for (int64_t k_index = 0; k_index < depth; ++k_index) { + dst += a[k_index] * b[oc_index + k_index * col]; + } + ActCompute(32, 0, C6NUM); + c[oc_index] = dst; + } +} + +#ifdef ENABLE_ARM64 +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "add x6, %[input], %[deep], LSL #3\n" + "add x7, x5, %[deep], LSL #3\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "dup v2.2d, xzr\n" + "dup v3.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "fmla v2.4s, v23.4s, v31.4s\n" + "fmla v3.4s, v27.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n" + "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], #32\n" + "ld1 {v20.4s, v21.4s}, [x6], #32\n" + "ld1 {v24.4s, v25.4s}, [x7], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + 
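+    // depth-remainder >= 8: rows 0-3 (x8, x5, x6, x7) reuse the same weight vectors v28/v29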
"fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v20.4s}, [x6], #16\n" + "ld1 {v24.4s}, [x7], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "dup v20.2d, xzr\n" + "dup v24.2d, xzr\n" + "dup v28.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v20.d}[0], [x6], #8\n" + "ld1 {v24.d}[0], [x7], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v20.s}[2], [x6]\n" + "ld1 {v24.s}[2], [x7]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v5.4s, v2.4s, v3.4s\n" + "faddp v0.4s, v4.4s, v5.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.4s, v0.4s, v1.4s\n" + "9:\n" + "cbz %[act], 10f\n" + "dup v1.2d, xzr\n" + "fmax v0.4s, v0.4s, v1.4s\n" + "cmp %[act], #3\n" + "bne 10f\n" + "movi v1.4s, #6\n" + "scvtf v1.4s, v1.4s\n" + "fmin v0.4s, v0.4s, v1.4s\n" + "10:\n" + "st1 {v0.4s}, [%[output]]\n" + + : + : [input] "r"(input), [weight] "r"(weight), [output] "r"(output), [bias] "r"(bias), [deep] "r"(deep), + [act] "r"(act_type) + : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], 
#32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.2s, v0.2s, v1.2s\n" + "9:\n" + "cbz %[act], 10f\n" + "fmov d1, xzr\n" + "fmax v0.2s, v0.2s, v1.2s\n" + "cmp %[act], #3\n" + "bne 10f\n" + "movi v1.2s, #6\n" + "scvtf v1.2s, v1.2s\n" + "fmin v0.2s, v0.2s, v1.2s\n" + "10:\n" + "st1 {v0.2s}, [%[output]]\n" + + : + : [input] "r"(input), [weight] "r"(weight), [output] "r"(output), [bias] "r"(bias), [deep] "r"(deep), + [act] "r"(act_type) + : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", + "v30", "v31", "memory"); +} + +void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "dup v0.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[3], [x8]\n" + "ld1 {v28.s}[3], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v0.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1 {v1.s}[0], [%[bias]]\n" + "fadd s0, s0, s1\n" + "9:\n" + "cbz %[act], 10f\n" + "fmov s1, wzr\n" + "fmax s0, s0, s1\n" + "cmp %[act], #3\n" + "bne 10f\n" + "mov x10, #6\n" + "scvtf s1, x10\n" + "fmin s0, s0, s1\n" + "10:\n" + "str s0, [%[output]]\n" + + : + : [input] 
"r"(input), [weight] "r"(weight), [output] "r"(output), [bias] "r"(bias), [deep] "r"(deep), + [act] "r"(act_type) + : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31"); +} + +void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, + int deep, int act_type) { + const float *input = a + start_row * deep; + float *output = c + start_row; + const int step = C4NUM * deep; + for (; start_row <= end_row - C4NUM; start_row += C4NUM) { + MatMul4x1Kernel(input, b, output, bias, deep, act_type); + input += step; + output += C4NUM; + } + for (; start_row <= end_row - C2NUM; start_row += C2NUM) { + MatMul2x1Kernel(input, b, output, bias, deep, act_type); + input += C2NUM * deep; + output += C2NUM; + } + if (start_row == end_row - 1) { + MatMul1x1Kernel(input, b, output, bias, deep, act_type); + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..fc75648d484045a585e250f9a21beecf38e26213 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_H_
+
+#include <float.h>
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/matmul_avx_fp32.h"
+
+#define ADD_BIAS(value, bias, c) \
+  if (bias != NULL) value = value + bias[c];
+
+#define DO_RELU(value, act_type) \
+  if (act_type == ActType_Relu) value = MSMAX(0.0f, value);
+
+#define DO_RELU6(value, act_type) \
+  if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \
+  if (act_type == ActType_Relu6) value = MSMAX(0.0f, value);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
+               int col, size_t stride, int out_type);
+void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+
+#ifdef ENABLE_ARM64
+void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col, size_t stride, size_t writeNhwc, size_t WriteWino);
+void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                          int col, size_t stride, size_t write_mode);
+void BigMatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                             int row, int col, size_t stride);
+void MatmulFloatNeon64OptRow8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                              int row, int col, size_t stride, size_t write_mode);
+void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                              int row, int col, size_t stride, size_t write_mode);
+void MatmulFloatNeon64OptRow12(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                               int row, int col, size_t stride, size_t write_mode);
+void MatVecMulPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
+                         int align_col);
+
+#elif defined(ENABLE_ARM32)
+void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col, int stride, size_t writeNhwc, size_t WriteWino);
+void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                          int col, int stride, int write_mode);
+void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                              int row, int col, int stride, int write_mode);
+
+#elif defined(ENABLE_SSE)
+void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col);
+void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                         int col, int stride, int write_mode);
+#endif
+
+void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, int out_type);
+
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type);
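+// GemmIsNotPack and the related helpers below take act_type as a plain integer; the kernels only honor
+// 0 (no activation), 1 (relu), and 3 (relu6), matching the ActType enum values.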
+
+void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                            int act_type);
+
+void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                                  int act_type);
+
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type);
+
+void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth,
+                         int64_t cur_col, int64_t col);
+#ifdef ENABLE_ARM64
+void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
+                        int deep, int act_type);
+#endif
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_FP32_MATMUL_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..2c91a3e6284db1273b00594e528588919d1dfa6d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+static inline int64_t GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int row,
+                                                      int deep, int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 b_data16 = SIMD_MOV_F32(b[0]);
+  SIMD_F32 bias_data16 = SIMD_MOV_F32(bias[0]);
+  for (int block_max_size = row - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 a_data = SIMD_LD_F32(a + index);
+    SIMD_F32 dst = SIMD_FMADD_F32(b_data16, a_data, bias_data16);
+    if (act_type != 0) {
+      dst = SIMD_MAX_F32(dst, down_threshold);
+      if (act_type == 3) {
+        dst = SIMD_MIN_F32(dst, up_threshold);
+      }
+    }
+    SIMD_ST_F32(c + index, dst);
+  }
+
+  return index;
+}
+
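+// The vector bodies in this template all share one activation contract. As a readability aid, here is a
+// minimal scalar sketch of GemmIsNotPack above (illustrative only: the helper name is hypothetical, and the
+// @SIMD_INSTRUCTION@ suffix only keeps each generated instantiation's copy uniquely named).
+static inline void GemmIsNotPackScalarRef@SIMD_INSTRUCTION@(const float *a, const float *b, float *c,
+                                                            const float *bias, int row, int act_type) {
+  for (int i = 0; i < row; ++i) {
+    float dst = a[i] * b[0] + bias[0];  // deep == 1: one multiply-add per output row
+    if (act_type != 0) {
+      dst = dst > 0.0f ? dst : 0.0f;  // relu lower bound, shared by act_type 1 and 3
+      if (act_type == 3) {
+        dst = dst < 6.0f ? dst : 6.0f;  // relu6 upper clamp
+      }
+    }
+    c[i] = dst;
+  }
+}
+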
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+static inline int64_t Row1Deep1GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int col,
+                                                               int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 vec_a = SIMD_MOV_F32(a[0]);
+  if (act_type == 1) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 vec_bias = SIMD_LD_F32(bias + index);
+      SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias);
+      SIMD_ST_F32(c + index, SIMD_MAX_F32(dst, down_threshold));  // relu
+    }
+  } else if (act_type == 3) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 vec_bias = SIMD_LD_F32(bias + index);
+      SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias);
+      SIMD_ST_F32(c + index, SIMD_CLAMP_F32(dst, down_threshold, up_threshold));  // relu6
+    }
+  } else {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 vec_bias = SIMD_LD_F32(bias + index);
+      SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias);
+      SIMD_ST_F32(c + index, dst);  // no_act
+    }
+  }
+
+  return index;
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+static inline int64_t Row1Deep1NoBiasGemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int col,
+                                                                     int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 vec_a = SIMD_MOV_F32(a[0]);
+  if (act_type == 1) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b);
+      SIMD_ST_F32(c + index, SIMD_MAX_F32(dst, down_threshold));  // relu
+    }
+  } else if (act_type == 3) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b);
+      SIMD_ST_F32(c + index, SIMD_CLAMP_F32(dst, down_threshold, up_threshold));  // relu6
+    }
+  } else {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b);
+      SIMD_ST_F32(c + index, dst);  // no_act
+    }
+  }
+
+  return index;
+}
+
+#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX)
+static inline int64_t GemmIsNotPackOptimizeCore@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, int k, float *dst) {
+  SIMD_F32 dst1 = SIMD_MOV_F32(0.0f);
+  for (int block_max_size = k - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 weight = SIMD_LD_F32(b + index);
+    SIMD_F32 a1 = SIMD_LD_F32(a + index);
+    dst1 = SIMD_FMADD_F32(weight, a1, dst1);
+  }
+  *dst += SIMD_REDUCE_ADD_F32(dst1);
+  return index;
+}
+#endif
+
+static inline int64_t MatVecMulNoPackCore@SIMD_INSTRUCTION@(int64_t oc_index, const float *a, const float *b, float *c, const float *bias,
+                                                            int act_type, int64_t depth, int64_t oc, int64_t col, int64_t inc_flag) {
+  for (int64_t oc_max_size = oc - BLOCK_NUM; oc_index <= oc_max_size; oc_index += BLOCK_NUM) {
+    SIMD_F32 out = (inc_flag & 0x1) == 0 ? SIMD_LD_F32(c + oc_index) : (bias == NULL ? 
SIMD_MOV_F32(0.0f) : SIMD_LD_F32(bias + oc_index)); + for (int64_t k = 0; k < depth; ++k) { + SIMD_F32 left = SIMD_MOV_F32(a[k]); + SIMD_F32 right = SIMD_LD_F32(b + oc_index + k * col); + out = SIMD_FMADD_F32(left, right, out); + } + if ((inc_flag & 0x2) != 0 && act_type != 0) { + out = SIMD_MAX_F32(out, SIMD_MOV_F32(0.0f)); + if (act_type == 0x3) { + out = SIMD_MIN_F32(out, SIMD_MOV_F32(6.0f)); + } + } + SIMD_ST_F32(c + oc_index, out); + } + return oc_index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..2846b09870950582ba79ad324d47dd931f3d30ce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c @@ -0,0 +1,187 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/mul_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/mul_fp32_simd.h" +#include "nnacl_c/errorcode.h" + +int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementMul(tile_in0, tile_in1, out, size); +} + +int ElementMul(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMul, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[index]; + } + return NNACL_OK; +} + +int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] * in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[index]; + } + return NNACL_OK; +} + +int ElementMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulReluInt, index, in0, in1, out, size); + for (; index < size; index++) { + int res = in0[index] * in1[index]; + out[index] = res > 0 ? 
res : 0; + } + return NNACL_OK; +} + +int ElementMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu6Int, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptMul(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] * in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] * in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulReluNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] * in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulReluNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] * in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6Num0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6Num1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementOptMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] * in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulReluIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] * in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulReluIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] * in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6IntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6IntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[0], 0), 6); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h new file mode 100644 index 
0000000000000000000000000000000000000000..41941d6205c8a7c2bb8237224b6bc0e647df0989 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_MUL_H_ +#define MINDSPORE_NNACL_FP32_MUL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementMul(const float *in0, const float *in1, float *out, int size); +int ElementMulRelu(const float *in0, const float *in1, float *out, int size); +int ElementMulRelu6(const float *in0, const float *in1, float *out, int size); +int ElementMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptMul(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..33bc1a37e00fe7fde39a7a31ba7347c2e13b5d82 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in @@ -0,0 +1,211 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_FP32_MUL_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_MUL_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int ElementMul@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MUL_F32(vin0, vin1);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1), 0.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1), 0.0f), 6.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0, vin1);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulReluInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1), 0.0f);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulRelu6Int@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1), 0.0f), 6.0f);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptMulNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MUL_F32(vin0_opt_, vin1);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptMulNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  SIMD_F32 
vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MUL_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0_opt_, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0_opt_, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0_opt_, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0_opt_, vin1), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, 
const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1_opt_), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6IntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0_opt_, vin1), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6IntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5fdea60ce9bc0c524fd4a4b76a8d30a9d09c3da8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/nllloss_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +int NLLLoss(const float *logits, const int32_t *labels, const float *weight, float *loss, float *total_weight, + const NLLLossStruct *nllloss, const ReductionType reduction_type) { + NNACL_CHECK_NULL_RETURN_ERR(logits); + NNACL_CHECK_NULL_RETURN_ERR(labels); + NNACL_CHECK_NULL_RETURN_ERR(weight); + NNACL_CHECK_NULL_RETURN_ERR(loss); + NNACL_CHECK_NULL_RETURN_ERR(total_weight); + + float total_loss = 0.0; + float tmp_total_weight = 0.0; + for (int i = 0; i < nllloss->batch_; i++) { + int index = i * nllloss->class_num_ + labels[i]; + float n_weight = weight[labels[i]]; + float n_loss = -logits[index] * n_weight; + tmp_total_weight += n_weight; + total_loss += n_loss; + if (reduction_type == Reduction_None) { + loss[i] = n_loss; + } + } + + *total_weight = tmp_total_weight; + if (reduction_type == Reduction_Sum) { + *loss = total_loss; + } else if (reduction_type == Reduction_Mean) { + *loss = total_loss / tmp_total_weight; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..e50de64106c8c938415649c520131052ec900214 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_NLLLOSS_FP32_H_ +#define NNACL_FP32_NLLLOSS_FP32_H_ + +#include "nnacl_c/kernel/nllloss.h" + +#ifdef __cplusplus +extern "C" { +#endif +int NLLLoss(const float *logits, const int32_t *labels, const float *weight, float *loss, float *total_weight, + const NLLLossStruct *parameter, const ReductionType reduction_type); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_NLLLOSS_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..63ed8ab33950405a245a4874b003e029570a3580 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c @@ -0,0 +1,205 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <float.h>
+#include <math.h>
+#include "nnacl_c/fp32/non_max_suppression_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+typedef struct {
+  int32_t batch_index_;
+  int32_t class_index_;
+  int32_t box_index_;
+} NMSIndex;
+
+typedef struct {
+  float score_;
+  int index_;
+  float y1_;  // y1 x1 y2 x2 ascending order
+  float y2_;
+  float x1_;
+  float x2_;
+  float area_;
+} NMSBox;
+
+void CreateNMBox(NMSBox *box, float score, int index, int cpb, float y_a, float x_a, float y_b, float x_b) {
+  box->score_ = score;
+  box->index_ = index;
+  if (0 == cpb) {
+    box->y1_ = NNACL_MIN(y_a, y_b);
+    box->y2_ = NNACL_MAX(y_a, y_b);
+    box->x1_ = NNACL_MIN(x_a, x_b);
+    box->x2_ = NNACL_MAX(x_a, x_b);
+  } else {
+    // x_center, y_center, width, height
+    float half_wid = x_b / 2;
+    box->x1_ = x_a - half_wid;
+    box->x2_ = x_a + half_wid;
+    float half_height = y_b / 2;
+    box->y1_ = y_a - half_height;
+    box->y2_ = y_a + half_height;
+  }
+  box->area_ = (box->y2_ - box->y1_) * (box->x2_ - box->x1_);
+}
+
+bool CheckIoUSuppressed(const NMSBox *box, const NMSBox *cand, float iou_threshold) {
+  float intersec_x1 = NNACL_MAX(cand->x1_, box->x1_);
+  float intersec_x2 = NNACL_MIN(cand->x2_, box->x2_);
+  float intersec_y1 = NNACL_MAX(cand->y1_, box->y1_);
+  float intersec_y2 = NNACL_MIN(cand->y2_, box->y2_);
+  const float intersec_area = NNACL_MAX(intersec_x2 - intersec_x1, 0.0f) * NNACL_MAX(intersec_y2 - intersec_y1, 0.0f);
+  if (intersec_area <= 0.0f) {
+    return false;
+  }
+  const float intersec_over_union = intersec_area / (cand->area_ + box->area_ - intersec_area);
+  return intersec_over_union > iou_threshold;
+}
+
+bool LessThan(NMSBox *box1, NMSBox *box2) {
+  return box1->score_ < box2->score_ ||
+         (fabs(box1->score_ - box2->score_) < FLT_EPSILON && box1->index_ > box2->index_);
+}
+
+void SortCandidates(ExecEnv *env, NMSBox **sorted, NMSBox *origin, int size) {
+  bool *sorted_index = (bool *)env->Alloc(env->allocator_, size * sizeof(bool));
+  NNACL_CHECK_NULL_RETURN_VOID(sorted_index);
+  memset(sorted_index, 0, size * sizeof(bool));
+
+  NMSBox min_box;
+  min_box.score_ = -FLT_MAX;  // sentinel below any real score
+  min_box.index_ = 0;
+
+  // Selection sort: each pass picks the best remaining candidate (highest
+  // score, ties broken by lower box index via LessThan) and fills sorted[]
+  // from the back, so sorted[] ends up in ascending score order and the
+  // consumer can pop the highest-scoring box from the tail.
+  for (int i = 0; i < size; i++) {
+    int max_index = 0;
+    NMSBox *box = &min_box;
+    for (int j = 0; j < size; j++) {
+      if (sorted_index[j]) {
+        continue;
+      }
+      if (LessThan(box, &origin[j])) {
+        max_index = j;
+        box = &origin[j];
+      }
+    }
+    sorted[size - 1 - i] = &origin[max_index];
+    sorted_index[max_index] = true;
+  }
+
+  env->Free(env->allocator_, sorted_index);
+  return;
+}
+
+int NonMaxSuppressionSelecte(NonMaxSuppressionStruct *nm_suppression, bool simple_out, int *score_dims) {
+  const float *box_data = (float *)nm_suppression->base_.in_[Index0]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(box_data);
+  const float *scores_data = (float *)nm_suppression->base_.in_[Index1]->data_;  // batch, class, num
+  NNACL_CHECK_NULL_RETURN_ERR(scores_data);
+  ExecEnv *env = nm_suppression->base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+
+  int batch_num = score_dims[Index0];
+  int class_num = score_dims[Index1];
+  int box_num = score_dims[Index2];
+
+  int selected_box_per_class_max_size = NNACL_MIN((int)box_num, nm_suppression->max_output_per_class_);
+  NNACL_CHECK_MALLOC_SIZE(selected_box_per_class_max_size);
+  NMSBox *selected_box_per_class = env->Alloc(env->allocator_, selected_box_per_class_max_size * sizeof(NMSBox));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(selected_box_per_class);
+  memset(selected_box_per_class, 0, selected_box_per_class_max_size * sizeof(NMSBox));
+  NMSBox *above_score_candidates = env->Alloc(env->allocator_, box_num * sizeof(NMSBox));
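+  /*
+   * Working buffers for the selection loop below (allocated here and just
+   * after): above_score_candidates holds every box of the current
+   * (batch, class) pair whose score exceeds score_threshold_;
+   * sorted_candidates points into it in ascending score order and is
+   * consumed from the back (highest score first); selected_box_per_class
+   * keeps the boxes accepted so far for IoU tests; selected_index records a
+   * (batch, class, box) triple per accepted box. A candidate is accepted
+   * unless its IoU with an already-accepted box exceeds iou_threshold_,
+   * with IoU = intersection / (area_a + area_b - intersection) as computed
+   * in CheckIoUSuppressed.
+   */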
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(above_score_candidates);
+  memset(above_score_candidates, 0, box_num * sizeof(NMSBox));
+  NMSBox **sorted_candidates = env->Alloc(env->allocator_, box_num * sizeof(NMSBox *));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sorted_candidates);
+  memset(sorted_candidates, 0, box_num * sizeof(NMSBox *));
+  int selected_index_max_size = batch_num * class_num * selected_box_per_class_max_size;
+  NNACL_CHECK_MALLOC_SIZE(selected_index_max_size);
+  int selected_index_size = 0;
+  NMSIndex *selected_index = env->Alloc(env->allocator_, selected_index_max_size * sizeof(NMSIndex));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(selected_index);
+
+  for (int i = 0; i < batch_num; ++i) {
+    int batch_offset = i * class_num * box_num;
+    for (int j = 0; j < class_num; ++j) {
+      // per batch per class filter
+      const float *per_class_scores = scores_data + batch_offset + j * box_num;
+      const float *box = box_data + i * box_num * Num4;
+      int above_score_candidates_size = 0;
+      for (int k = 0; k < box_num; ++k) {
+        if (per_class_scores[k] > nm_suppression->score_threshold_) {
+          CreateNMBox(&above_score_candidates[above_score_candidates_size++], per_class_scores[k], k,
+                      nm_suppression->center_point_box_, box[Index0], box[Index1], box[Index2], box[Index3]);
+        }
+        box += Num4;
+      }
+
+      int sorted_candidates_size = above_score_candidates_size;
+      SortCandidates(env, sorted_candidates, above_score_candidates, above_score_candidates_size);
+
+      int selected_box_per_class_size = 0;
+      while (sorted_candidates_size > 0 && selected_box_per_class_size < selected_box_per_class_max_size) {
+        NMSBox *cand = sorted_candidates[sorted_candidates_size - 1];
+        bool selected = true;
+        for (int k = 0; k < selected_box_per_class_size; k++) {
+          if (CheckIoUSuppressed(&selected_box_per_class[k], cand, nm_suppression->iou_threshold_)) {
+            selected = false;
+            break;
+          }
+        }
+
+        if (selected) {
+          selected_box_per_class[selected_box_per_class_size++] = *cand;
+          selected_index[selected_index_size].batch_index_ = i;
+          selected_index[selected_index_size].class_index_ = j;
+          selected_index[selected_index_size].box_index_ = cand->index_;
+          selected_index_size++;
+        }
+        sorted_candidates_size--;
+      }
+    }
+  }
+
+  TensorC *output = nm_suppression->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+  if (!simple_out) {
+    const int output_last_dim = Num3;
+    int output_shape[] = {selected_index_size, output_last_dim};
+    output->shape_size_ = Num2;
+    memcpy(output->shape_, output_shape, output->shape_size_ * sizeof(int));
+    int output_size = selected_index_size * sizeof(NMSIndex);
+    if (output_size != NNACLGetSize(output)) {
+      return NNACL_NON_MAX_SUPPRESSION_OUTPUT_SIZE_UNMATCH;
+    }
+    int *out_data = (int *)env->Alloc(env->allocator_, output_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_data);
+    output->data_ = out_data;
+    memcpy(out_data, selected_index, output_size);
+  } else {
+    int output_shape[] = {selected_index_size};
+    output->shape_size_ = Num1;
+    memcpy(output->shape_, output_shape, output->shape_size_ * sizeof(int));
+    int *out_data = (int *)env->Alloc(env->allocator_, NNACLGetSize(output));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_data);
+    output->data_ = out_data;
+    for (int i = 0; i < selected_index_size; i++) {
+      out_data[i] = selected_index[i].box_index_;
+    }
+  }
+
+  env->Free(env->allocator_, selected_box_per_class);
+  env->Free(env->allocator_, above_score_candidates);
+  env->Free(env->allocator_, sorted_candidates);
+  env->Free(env->allocator_, selected_index);
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..61e37d701d5c757a371549fb9185ae1fd250d60c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_
+#define NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/kernel/non_max_suppression.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int NonMaxSuppressionSelecte(NonMaxSuppressionStruct *nm_suppression, bool simple_out, int *score_dims);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..d7d15361e2be2f738a8925881a0eab1b431ab7e1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/one_hot_fp32.h" +#include "nnacl_c/errorcode.h" + +int OneHotToFp32(const int32_t *indices, float on_value, float off_value, float *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num) { + if (indices == NULL || one_hot_param == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int32_t *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..e61b429c5b775d667b8f57e8cacff99c936d48b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_ONE_HOT_FP32_H_ +#define NNACL_FP32_ONE_HOT_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/one_hot_parameter.h" +#include "nnacl_c/kernel/one_hot.h" + +#ifdef __cplusplus +extern "C" { +#endif +int OneHotToFp32(const int32_t *indices, float on_value, float off_value, float *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_ONE_HOT_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a0133ccd24c5a6aaad10e418d178a0ef1632fe72 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c @@ -0,0 +1,69 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/cast_gather_reduce_fp32_simd.h" + +int64_t Fp32CastGatherReduceInt64Fusion(float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + int index = 0; + SIMD_RUN_NO_SCALAR(Fp32CastGatherReduceInt64Fusion, index, output_data, input_indices, input_data, inner_size, + input_data_inner_size, outer_start, outer_end); + + if (index < input_data_inner_size) { + for (int i = outer_start; i < outer_end; i++) { + float *result = output_data + i * input_data_inner_size + index; + int64_t indice0 = input_indices[i * inner_size]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] = input_data[indice0 * input_data_inner_size + k]; + } + for (int j = 1; j < inner_size; j++) { + int64_t indice = input_indices[i * inner_size + j]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] += input_data[indice * input_data_inner_size + k]; + } + } + } + } + return NNACL_OK; +} + +int64_t Fp32CastGatherReduceInt32Fusion(float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + int index = 0; + SIMD_RUN_NO_SCALAR(Fp32CastGatherReduceInt32Fusion, index, output_data, input_indices, input_data, inner_size, + input_data_inner_size, outer_start, outer_end); + + if (index < input_data_inner_size) { + for (int i = outer_start; i < outer_end; i++) { + float *result = output_data + i * input_data_inner_size + index; + int32_t indice0 = input_indices[i * inner_size]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] = input_data[indice0 * input_data_inner_size + k]; + } + for (int j = 1; j < inner_size; j++) { + int32_t indice = input_indices[i * inner_size + j]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] += input_data[indice * input_data_inner_size + k]; + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..30c2efc1d80cef20b720b53dcef84ed1107ac4d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32CastGatherReduceInt64Fusion(float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end); + +int64_t Fp32CastGatherReduceInt32Fusion(float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..1bfd16fa45fb3df6ff1dfe07537af206f0c622b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in @@ -0,0 +1,65 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_CAST_GATHER_REDUCE_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_CAST_GATHER_REDUCE_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int Fp32CastGatherReduceInt64Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const int64_t *input_indices, const float *input_data,
+                                                                    int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start,
+                                                                    int32_t outer_end) {
+  for (int block_max_size = input_data_inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    for (int i = outer_start; i < outer_end; i++) {
+      SIMD_F32 result = SIMD_SET0_F32;
+      for (int j = 0; j < inner_size; j++) {
+        int64_t indice = input_indices[i * inner_size + j];
+        result = SIMD_ADD_F32(result, SIMD_LD_F32(input_data + indice * input_data_inner_size + index));
+      }
+      SIMD_ST_F32(output_data + i * input_data_inner_size + index, result);
+    }
+  }
+  return index;
+}
+
+static inline int Fp32CastGatherReduceInt32Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const int32_t *input_indices, const float *input_data,
+                                                                    int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start,
+                                                                    int32_t outer_end) {
+  for (int block_max_size = input_data_inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    for (int i = outer_start; i < outer_end; i++) {
+      SIMD_F32 result = SIMD_SET0_F32;
+      for (int j = 0; j < inner_size; j++) {
+        int32_t indice = input_indices[i * inner_size + j];
+        result = SIMD_ADD_F32(result, SIMD_LD_F32(input_data + indice * input_data_inner_size + index));
+      }
+      SIMD_ST_F32(output_data + i * input_data_inner_size + index, result);
+    }
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..b4b426986ed5503fa0eab7438c97068f92781f7c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c
@@ -0,0 +1,124 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/online_fusion/reduce_concat_fp32.h" +#include +#include "nnacl_c/reduce_concat_fp32_simd.h" +#include "nnacl_c/errorcode.h" + +int64_t Fp32ReduceSumConcatAxisSizeAVX512Fusion(float *output_data, float **input_datas, + const int64_t *reduce_axis_size, int64_t input_nums, int64_t batch, + int64_t batch_tile_size, int64_t inner_tile, int64_t thread_num, + int64_t task_id) { + int64_t single_thread_tile = DOWN_DIV(batch, thread_num); + int64_t less_tile = batch - thread_num * single_thread_tile; + + int64_t batch_start = task_id * single_thread_tile; + if (task_id < less_tile) { + single_thread_tile += 1; + batch_start += task_id; + } else { + batch_start += less_tile; + } + int64_t batch_end = batch_start + single_thread_tile; + int64_t last_inner_size = batch_tile_size - (input_nums - 1) * inner_tile; + + int res = NNACL_OK; + if (inner_tile == C16NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize16Fusion, res, result, + input_datas[j] + i * C16NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C16NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C32NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize32Fusion, res, result, + input_datas[j] + i * C32NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C32NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C64NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize64Fusion, res, result, + input_datas[j] + i * C64NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C64NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C128NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize128Fusion, res, result, + input_datas[j] + i * C128NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C128NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } + return res; +} + +int64_t Fp32ReduceSumConcatFusion(float *output_data, float **input_datas, const int64_t *reduce_axis_size, + int64_t input_nums, int64_t batch, int64_t batch_tile_size, int64_t inner_tile, + int64_t thread_num, int64_t task_id) { + AVX512_HARDWARE_SELF_AWARENESS_BEGIN; + if (inner_tile == C16NUM || inner_tile == C32NUM || inner_tile == C64NUM || inner_tile == C128NUM) { + return Fp32ReduceSumConcatAxisSizeAVX512Fusion(output_data, input_datas, reduce_axis_size, input_nums, batch, + batch_tile_size, inner_tile, thread_num, task_id); + } + AVX512_HARDWARE_SELF_AWARENESS_END; + + int64_t single_thread_tile = DOWN_DIV(batch, thread_num); + int64_t less_tile = batch - thread_num * single_thread_tile; + + int64_t batch_start = task_id * single_thread_tile; + if (task_id < less_tile) { + 
batch_start += task_id; + single_thread_tile += 1; + } else { + batch_start += less_tile; + } + int64_t batch_end = batch_start + single_thread_tile; + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + const float *input_data_ptr = input_datas[j] + i * inner_tile * reduce_axis_size[j]; + + for (int k = 0; k < inner_tile; k++) { + result[k] = input_data_ptr[k]; + for (int l = 1; l < reduce_axis_size[j]; l++) { + result[k] += input_data_ptr[l * inner_tile + k]; + } + } + result += inner_tile; + } + + int64_t inner_size2 = batch_tile_size - (input_nums - 1) * inner_tile; + const float *input_data_ptr = input_datas[input_nums - 1] + i * inner_size2; + (void)memcpy(result, input_data_ptr, inner_size2 * sizeof(float)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c0e586b457c330b4412513a651b9dda130b65b3d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32ReduceSumConcatFusion(float *output_data, float **input_datas, const int64_t *reduce_axis_size, + int64_t input_nums, int64_t batch, int64_t batch_tile_size, int64_t inner_tile, + int64_t thread_num, int64_t task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..0fd1e6125e7405e84867e0b0c252c474794d8b50 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in @@ -0,0 +1,115 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_REDUCE_CONCAT_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_REDUCE_CONCAT_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+#ifdef MS_SIMD_AVX512
+static inline int Fp32ReduceSumConcatAxisSize16Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (1 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data));
+  }
+  SIMD_ST_F32(output_data, zmm00);
+  return index;
+}
+
+static inline int Fp32ReduceSumConcatAxisSize32Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM);
+  SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (2 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM));
+    zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM));
+  }
+
+  SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00);
+  SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01);
+
+  return index;
+}
+
+static inline int Fp32ReduceSumConcatAxisSize64Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM);
+  SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM);
+  SIMD_F32 zmm02 = SIMD_LD_F32(input_data + 2 * BLOCK_NUM);
+  SIMD_F32 zmm03 = SIMD_LD_F32(input_data + 3 * BLOCK_NUM);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (4 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM));
+    zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM));
+    zmm02 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 2 * BLOCK_NUM));
+    zmm03 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 3 * BLOCK_NUM));
+  }
+
+  SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00);
+  SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01);
+  SIMD_ST_F32(output_data + 2 * BLOCK_NUM, zmm02);
+  SIMD_ST_F32(output_data + 3 * BLOCK_NUM, zmm03);
+
+  return index;
+}
+
+static inline int Fp32ReduceSumConcatAxisSize128Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM);
+  SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM);
+  SIMD_F32 zmm02 = SIMD_LD_F32(input_data + 2 * BLOCK_NUM);
+  SIMD_F32 zmm03 = SIMD_LD_F32(input_data + 3 * BLOCK_NUM);
+  SIMD_F32 zmm04 = SIMD_LD_F32(input_data + 4 * BLOCK_NUM);
+  SIMD_F32 zmm05 = SIMD_LD_F32(input_data + 5 * BLOCK_NUM);
+  SIMD_F32 zmm06 = SIMD_LD_F32(input_data + 6 * BLOCK_NUM);
+  SIMD_F32 zmm07 = SIMD_LD_F32(input_data + 7 * BLOCK_NUM);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (8 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM));
+    zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM));
+    zmm02 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 2 * BLOCK_NUM));
+    zmm03 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 3 * BLOCK_NUM));
+    zmm04 = SIMD_ADD_F32(zmm04, SIMD_LD_F32(input_data + 4 * BLOCK_NUM));
+    zmm05 = SIMD_ADD_F32(zmm05,
SIMD_LD_F32(input_data + 5 * BLOCK_NUM));
+    zmm06 = SIMD_ADD_F32(zmm06, SIMD_LD_F32(input_data + 6 * BLOCK_NUM));
+    zmm07 = SIMD_ADD_F32(zmm07, SIMD_LD_F32(input_data + 7 * BLOCK_NUM));
+  }
+
+  SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00);
+  SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01);
+  SIMD_ST_F32(output_data + 2 * BLOCK_NUM, zmm02);
+  SIMD_ST_F32(output_data + 3 * BLOCK_NUM, zmm03);
+  SIMD_ST_F32(output_data + 4 * BLOCK_NUM, zmm04);
+  SIMD_ST_F32(output_data + 5 * BLOCK_NUM, zmm05);
+  SIMD_ST_F32(output_data + 6 * BLOCK_NUM, zmm06);
+  SIMD_ST_F32(output_data + 7 * BLOCK_NUM, zmm07);
+
+  return index;
+}
+
+#endif
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..adcc2c8c5caa7e3dea573c37cfd272441d34b7a6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h"
+#include <string.h>
+#include "nnacl_c/reduce_fp32_simd.h"
+#include "nnacl_c/errorcode.h"
+
+int64_t Fp32SplitReduceSumConcatFusion(const float *src, float *dst, int64_t inner_size, int64_t mid_size,
+                                       int32_t *mid_split, int64_t mid_len, int64_t out_size) {
+  const float *cur_src = src;
+  float *cur_dst = dst;
+  for (int64_t i = 0; i < out_size; i++) {
+    for (int64_t j = 0; j < mid_len; j++) {
+      int k = 0;
+      SIMD_RUN_NO_SCALAR(ReduceSum, k, cur_src, cur_dst, inner_size, mid_split[j]);
+      for (; k < inner_size; k++) {
+        float result = cur_src[k];
+        for (int64_t l = 1; l < mid_split[j]; l++) {
+          result += cur_src[inner_size * l + k];
+        }
+        cur_dst[k] = result;
+      }
+      cur_src += (inner_size * mid_split[j]);
+      cur_dst += inner_size;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..5faecf5ac7a7f8a35dbbdb32183989bac19cca30
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32SplitReduceSumConcatFusion(const float *src, float *dst, int64_t inner_size, int64_t mid_size, + int32_t *mid_split, int64_t mid_len, int64_t out_size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9dccfd8441a3cbc3ca404798e1e35b72b50d08b4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c @@ -0,0 +1,2078 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { + PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0); +} + +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float)); + } + } +} + +void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) { + if (channel <= C4NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float)); + return; + } + int tmp = DOWN_DIV(channel, C4NUM); + int c_res = channel - tmp * C4NUM; + int c4_block = tmp * plane * C4NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C4NUM; + int c = 0; + for (; c <= channel - C4NUM; c += C4NUM) { +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 src_data = MS_LDQ_F32(src + src_kernel_offset + c); + MS_STQ_F32(dst + dst_kernel_offset + c * plane, src_data); +#else + for (int k1 = 0; k1 < C4NUM; ++k1) { + (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1]; + } +#endif + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c4_block + k * c_res + c - tmp * C4NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) { + if (channel <= C8NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float)); + return; + } + int tmp = DOWN_DIV(channel, C8NUM); + int c_res = channel - tmp * C8NUM; + int c8_block = tmp * 
plane * C8NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C8NUM; + int c = 0; + for (; c <= channel - C8NUM; c += C8NUM) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 src_data = MS_LD256_F32(src + src_kernel_offset + c); + MS_ST256_F32(dst + dst_kernel_offset + c * plane, src_data); +#else + for (int k1 = 0; k1 < C8NUM; ++k1) { + (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1]; + } +#endif + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; ++r) { + for (int c = 0; c < col; ++c) { + dst_ptr[c * row + r] = src_ptr[r * col + c]; + } + } +} + +void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + if (row_end > row_start) { + src_ptr += row_start * col; + dst_ptr += row_start * col; + memcpy(dst_ptr, src_ptr, (row_end - row_start) * col * (int)(sizeof(float))); + } +} + +void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c]; + } + for (; c < UP_ROUND(col, C4NUM); c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0; + } + } + return; +} + +void RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c]; + } + for (; c < UP_ROUND(col, C6NUM); c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0; + } + } + return; +} + +void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; + } + for (; c < UP_ROUND(col, C8NUM); c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0; + } + } + return; +} + +void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c]; + } + for (; c < UP_ROUND(col, C12NUM); c++) { + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0; + } + } + return; +} + +void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr 
+ r * col;
+    int c = 0;
+    for (; c < col; c++) {
+      int cd16 = c / C16NUM;
+      int cm16 = c % C16NUM;
+      dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
+    }
+    for (; c < UP_ROUND(col, C16NUM); c++) {
+      int cd16 = c / C16NUM;
+      int cm16 = c % C16NUM;
+      dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0;
+    }
+  }
+  return;
+}
+
+void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) {
+  // Not exactly aligned to 32, but aligned to 24 or 16 or 8 if 32 is not met.
+  int row_block_num = UP_DIV(row, C8NUM);
+  int row_block = C4NUM;
+  for (int i = 0; i < row_block_num; i += row_block) {
+    row_block = MSMIN(C4NUM, row_block_num - i);  // max_tile = 4
+    int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
+    dst_ptr += col_start * row_block * C8NUM;
+    for (int oc = col_start; oc < col_end; ++oc) {
+      memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
+      dst_ptr += row_block * C8NUM;
+    }
+    dst_ptr += (col - col_end) * row_block * C8NUM;
+  }
+}
+
+void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) {
+  // Not exactly aligned to 64, but aligned to 48 or 32 or 16 if 64 is not met.
+  int row_block_num = UP_DIV(row, C16NUM);
+  int row_block = C4NUM;
+  for (int i = 0; i < row_block_num; i += row_block) {
+    row_block = MSMIN(C4NUM, row_block_num - i);  // max_tile = 4
+    int row_remainder = MSMIN(row_block * C16NUM, row - i * C16NUM);
+    dst_ptr += col_start * row_block * C16NUM;
+    for (int oc = col_start; oc < col_end; ++oc) {
+      memcpy(dst_ptr, src_ptr + oc * row + i * C16NUM, row_remainder * sizeof(float));
+      dst_ptr += row_block * C16NUM;
+    }
+    dst_ptr += (col - col_end) * row_block * C16NUM;
+  }
+}
+
+#ifdef ENABLE_ARM64
+void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
+  size_t stride = col * sizeof(float);
+  asm volatile(
+    "mov x10, %[src_c]\n"
+    "mov x11, %[dst_c]\n"
+
+    "ld1 {v0.4s}, [x10], %[stride]\n"
+    "ld1 {v1.4s}, [x10], %[stride]\n"
+    "ld1 {v2.4s}, [x10], %[stride]\n"
+    "ld1 {v3.4s}, [x10], %[stride]\n"
+
+    "ld1 {v4.4s}, [x10], %[stride]\n"
+    "ld1 {v5.4s}, [x10], %[stride]\n"
+    "ld1 {v6.4s}, [x10], %[stride]\n"
+    "ld1 {v7.4s}, [x10], %[stride]\n"
+
+    "zip1 v12.4s, v0.4s, v1.4s\n"
+    "zip2 v13.4s, v0.4s, v1.4s\n"
+    "zip1 v14.4s, v2.4s, v3.4s\n"
+    "zip2 v15.4s, v2.4s, v3.4s\n"
+
+    "ld1 {v8.4s}, [x10], %[stride]\n"
+    "ld1 {v9.4s}, [x10], %[stride]\n"
+    "ld1 {v10.4s}, [x10], %[stride]\n"
+    "ld1 {v11.4s}, [x10], %[stride]\n"
+
+    "zip1 v16.4s, v4.4s, v5.4s\n"
+    "zip2 v17.4s, v4.4s, v5.4s\n"
+    "zip1 v18.4s, v6.4s, v7.4s\n"
+    "zip2 v19.4s, v6.4s, v7.4s\n"
+
+    "trn1 v20.2d, v12.2d, v14.2d\n"
+    "trn2 v23.2d, v12.2d, v14.2d\n"
+    "trn1 v26.2d, v13.2d, v15.2d\n"
+    "trn2 v29.2d, v13.2d, v15.2d\n"
+
+    "trn1 v21.2d, v16.2d, v18.2d\n"
+    "trn2 v24.2d, v16.2d, v18.2d\n"
+    "trn1 v27.2d, v17.2d, v19.2d\n"
+    "trn2 v30.2d, v17.2d, v19.2d\n"
+
+    "zip1 v12.4s, v8.4s, v9.4s\n"
+    "zip2 v13.4s, v8.4s, v9.4s\n"
+    "zip1 v14.4s, v10.4s, v11.4s\n"
+    "zip2 v15.4s, v10.4s, v11.4s\n"
+
+    "trn1 v22.2d, v12.2d, v14.2d\n"
+    "trn2 v25.2d, v12.2d, v14.2d\n"
+    "trn1 v28.2d, v13.2d, v15.2d\n"
+    "trn2 v31.2d, v13.2d, v15.2d\n"
+
+    "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
+    "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
+    "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
+
+    :
+    : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride)
+    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); + return; +} +#endif +#ifdef ENABLE_ARM32 +void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q10}, [r10], %[stride]\n" + "vld1.32 {q13}, [r10], %[stride]\n" + + "vtrn.32 d0, d6\n" + "vtrn.32 d1, d7\n" + "vtrn.32 d20, d26\n" + "vtrn.32 d21, d27\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q8}, [r10], %[stride]\n" + "vld1.32 {q11}, [r10], %[stride]\n" + "vld1.32 {q14}, [r10], %[stride]\n" + + "vswp d1, d20\n" + "vswp d7, d26\n" + + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q9}, [r10], %[stride]\n" + "vld1.32 {q12}, [r10], %[stride]\n" + "vld1.32 {q15}, [r10], %[stride]\n" + + "vtrn.32 d2, d16\n" + "vtrn.32 d3, d17\n" + "vtrn.32 d22, d28\n" + "vtrn.32 d23, d29\n" + + "vswp d3, d22\n" + "vswp d17, d28\n" + + "vtrn.32 d4, d18\n" + "vtrn.32 d5, d19\n" + "vtrn.32 d24, d30\n" + "vtrn.32 d25, d31\n" + + "vswp d5, d24\n" + "vswp d19, d30\n" + + "vst1.32 {q0, q1}, [r12]!\n" + "vst1.32 {q2, q3}, [r12]!\n" + "vst1.32 {q8, q9}, [r12]!\n" + "vst1.32 {q10, q11}, [r12]!\n" + "vst1.32 {q12, q13}, [r12]!\n" + "vst1.32 {q14, q15}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + return; +} +#endif +void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int ri = (row_start / C12NUM * C12NUM); + float *dst_r = dst_ptr + ri * col; + const float *src_r = src_ptr + ri * col; + for (; ri < (row_end / C12NUM * C12NUM); ri += C12NUM) { + int ci = 0; + for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM64 + RowMajor2Col12Major_arm64(src_c, dst_c, col); +#elif ENABLE_ARM32 + RowMajor2Col12Major_arm32(src_c, dst_c, col); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); + __m128 src12H = _mm_unpackhi_ps(src1, src2); + __m128 src34L = _mm_unpacklo_ps(src3, src4); + __m128 src34H = _mm_unpackhi_ps(src3, src4); + + __m128 dst0 = _mm_movelh_ps(src12L, src34L); + __m128 dst3 = _mm_movehl_ps(src34L, src12L); + __m128 dst6 = _mm_movelh_ps(src12H, src34H); + __m128 dst9 = _mm_movehl_ps(src34H, src12H); + + __m128 src5 = _mm_loadu_ps(src_c); + __m128 src6 = _mm_loadu_ps(src_c + col); + __m128 src7 = _mm_loadu_ps(src_c + 2 * col); + __m128 src8 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src56L = _mm_unpacklo_ps(src5, src6); + __m128 src56H = _mm_unpackhi_ps(src5, src6); + __m128 src78L = _mm_unpacklo_ps(src7, src8); + __m128 src78H = _mm_unpackhi_ps(src7, src8); + __m128 dst1 = _mm_movelh_ps(src56L, src78L); + __m128 dst4 = _mm_movehl_ps(src78L, src56L); + __m128 dst7 = _mm_movelh_ps(src56H, src78H); + __m128 dst10 = _mm_movehl_ps(src78H, src56H); + + __m128 src9 = _mm_loadu_ps(src_c); + __m128 src10 = _mm_loadu_ps(src_c + col); + __m128 src11 = _mm_loadu_ps(src_c + 2 * col); + __m128 src12 = _mm_loadu_ps(src_c + 3 * col); + src_c 
+= C4NUM * col; + __m128 src910L = _mm_unpacklo_ps(src9, src10); + __m128 src910H = _mm_unpackhi_ps(src9, src10); + __m128 src1112L = _mm_unpacklo_ps(src11, src12); + __m128 src1112H = _mm_unpackhi_ps(src11, src12); + __m128 dst2 = _mm_movelh_ps(src910L, src1112L); + __m128 dst5 = _mm_movehl_ps(src1112L, src910L); + __m128 dst8 = _mm_movelh_ps(src910H, src1112H); + __m128 dst11 = _mm_movehl_ps(src1112H, src910H); + + _mm_storeu_ps(dst_c, dst0); + _mm_storeu_ps(dst_c + 4, dst1); + _mm_storeu_ps(dst_c + 8, dst2); + _mm_storeu_ps(dst_c + 12, dst3); + _mm_storeu_ps(dst_c + 16, dst4); + _mm_storeu_ps(dst_c + 20, dst5); + _mm_storeu_ps(dst_c + 24, dst6); + _mm_storeu_ps(dst_c + 28, dst7); + _mm_storeu_ps(dst_c + 32, dst8); + _mm_storeu_ps(dst_c + 36, dst9); + _mm_storeu_ps(dst_c + 40, dst10); + _mm_storeu_ps(dst_c + 44, dst11); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + for (int i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + if (row_end == row) { + for (; ri < row_end; ri++, dst_r++, src_r += col) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = src_r[i]; + } + } + for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + } + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n" + "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + + "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n" + "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n" + "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n" + "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v0.2d, v8.2d, v10.2d\n" + "trn2 v1.2d, v8.2d, v10.2d\n" + "trn1 v2.2d, v9.2d, v11.2d\n" + "trn2 v3.2d, v9.2d, v11.2d\n" + + "zip1 v24.4s, v16.4s, v18.4s\n" + "zip2 v25.4s, v16.4s, v18.4s\n" + "zip1 v26.4s, v20.4s, v22.4s\n" + "zip2 v27.4s, v20.4s, v22.4s\n" + + "trn1 v4.2d, v12.2d, v14.2d\n" + "trn2 v5.2d, v12.2d, v14.2d\n" + "trn1 v6.2d, v13.2d, v15.2d\n" + "trn2 v7.2d, v13.2d, v15.2d\n" + + "zip1 v28.4s, v17.4s, v19.4s\n" + "zip2 v29.4s, v17.4s, v19.4s\n" + "zip1 v30.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + + "trn1 v16.2d, v24.2d, v26.2d\n" + "trn2 v17.2d, v24.2d, v26.2d\n" + "trn1 v18.2d, v25.2d, v27.2d\n" + "trn2 v19.2d, v25.2d, v27.2d\n" + + "trn1 v20.2d, v28.2d, v30.2d\n" + "trn2 v21.2d, v28.2d, v30.2d\n" + "trn1 v22.2d, v29.2d, v31.2d\n" + "trn2 v23.2d, v29.2d, v31.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v16.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v17.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11], #16\n" + "st1 {v18.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11], #16\n" + "st1 {v19.4s}, [x11], #16\n" + "st1 {v4.4s}, [x11], #16\n" + "st1 {v20.4s}, [x11], #16\n" + "st1 {v5.4s}, [x11], #16\n" + "st1 {v21.4s}, [x11], #16\n" + "st1 {v6.4s}, [x11], #16\n" + "st1 {v22.4s}, [x11], 
#16\n" + "st1 {v7.4s}, [x11], #16\n" + "st1 {v23.4s}, [x11], #16\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); + return; +} +#endif +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r11, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r11]!\n" + "vst1.32 {q2, q3}, [r11]!\n" + "vst1.32 {q4, q5}, [r11]!\n" + "vst1.32 {q6, q7}, [r11]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + return; +} +#else +void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r7, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r7]!\n" + "vst1.32 {q2, q3}, [r7]!\n" + "vst1.32 {q4, q5}, [r7]!\n" + "vst1.32 {q6, q7}, [r7]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + return; +} +#endif +#endif +void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row8 = row_end / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + int col_skip = col / C8NUM * C8NUM; + int skip_size = C8NUM; +#else + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; +#endif + int ri = (row_start / C8NUM * C8NUM); + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8Major_arm64(src_c, dst_c, col); +#elif ENABLE_ARM32 + RowMajor2Col8Major_arm32(src_c, dst_c, col); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); // x5 + __m128 src12H = _mm_unpackhi_ps(src1, src2); // x1 + 
__m128 src34L = _mm_unpacklo_ps(src3, src4); // x + __m128 src34H = _mm_unpackhi_ps(src3, src4); + _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L)); + _mm_storeu_ps(dst_c + C8NUM, _mm_movehl_ps(src34L, src12L)); + _mm_storeu_ps(dst_c + C16NUM, _mm_movelh_ps(src12H, src34H)); + _mm_storeu_ps(dst_c + C24NUM, _mm_movehl_ps(src34H, src12H)); + + __m128 src5 = _mm_loadu_ps(src_c); + __m128 src6 = _mm_loadu_ps(src_c + col); + __m128 src7 = _mm_loadu_ps(src_c + 2 * col); + __m128 src8 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src56L = _mm_unpacklo_ps(src5, src6); + __m128 src56H = _mm_unpackhi_ps(src5, src6); + __m128 src78L = _mm_unpacklo_ps(src7, src8); + __m128 src78H = _mm_unpackhi_ps(src7, src8); + _mm_storeu_ps(dst_c + C4NUM, _mm_movelh_ps(src56L, src78L)); + _mm_storeu_ps(dst_c + C12NUM, _mm_movehl_ps(src78L, src56L)); + _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H)); + _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H)); +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + } + + for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = 0; + } + } + } +} + +void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row16 = row_end / C16NUM * C16NUM; + int ri = row_start / C16NUM * C16NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row16; ri += C16NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_AVX + Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM); + Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + int total_row = UP_ROUND(row, C16NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + // Not exactly aligned to 32, but aligned to 24 or 16 or 8 If 32 is not met. 
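+  // e.g. row = 52: UP_DIV(52, C8NUM) = 7 blocks of 8 rows, packed below as one group of
+  // 4 blocks (dst_stride = 32) followed by one group of 3 blocks (dst_stride = 24); that
+  // grouping is where the 24/16/8 fallback mentioned above comes from.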
+#ifdef ENABLE_AVX + int col8 = col / C8NUM * C8NUM; +#endif + int all_block_num = UP_DIV(row, C8NUM); + int cur_block = C4NUM; + row_start = UP_DIV(row_start, C8NUM); + row_end = UP_DIV(row_end, C8NUM); + for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) { + cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 + int dst_stride = cur_block * C8NUM; + int row_num = MSMIN(dst_stride, row - i * C8NUM); +#ifdef ENABLE_AVX + int row8_num = row_num / C8NUM * C8NUM; +#endif + const float *src = src_ptr + i * C8NUM * col; + float *dst = dst_ptr + i * C8NUM * col; + int r = 0; +#ifdef ENABLE_AVX + for (; r < row8_num; r += C8NUM) { + int c = 0; + for (; c < col8; c += C8NUM) { + Transpose8X8Fp32Avx(src + r * col + c, dst + c * dst_stride + r, col, dst_stride); + } + for (; c < col; ++c) { + for (int k = 0; k < C8NUM; ++k) { + dst[c * dst_stride + r + k] = src[r * col + c + k * col]; + } + } + } +#endif + for (; r < row_num; r++) { + for (int c = 0; c < col; ++c) { + dst[c * dst_stride + r] = src[r * col + c]; + } + } + } +} + +void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + // Not exactly aligned to 64, but aligned to 48 or 32 or 16 If 64 is not met. + int all_block_num = UP_DIV(row, C16NUM); + int cur_block = C4NUM; + row_start = UP_DIV(row_start, C16NUM); + row_end = UP_DIV(row_end, C16NUM); + for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) { + cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 + int dst_stride = cur_block * C16NUM; + int row_num = MSMIN(dst_stride, row - i * C16NUM); + const float *src = src_ptr + i * C16NUM * col; + float *dst = dst_ptr + i * C16NUM * col; + int r = 0; + for (; r < row_num; r++) { + for (int c = 0; c < col; ++c) { + dst[c * dst_stride + r] = src[r * col + c]; + } + } + } +} + +void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row6 = row_end / C6NUM * C6NUM; + int ri = row_start / C6NUM * C6NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row6; ri += C6NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + col); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * col); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * col); + __m256 src4 = _mm256_loadu_ps(src_c + 4 * col); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * col); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, 
lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + 4, res1); + _mm_storeu_ps(dst_c + 8, res2); + _mm_storeu_ps(dst_c + 12, res3); + _mm_storeu_ps(dst_c + 16, res4); + _mm_storeu_ps(dst_c + 20, res5); + _mm_storeu_ps(dst_c + 24, res6); + _mm_storeu_ps(dst_c + 28, res7); + _mm_storeu_ps(dst_c + 32, res8); + _mm_storeu_ps(dst_c + 36, res9); + _mm_storeu_ps(dst_c + 40, res10); + _mm_storeu_ps(dst_c + 44, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (int i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C6NUM * col; + dst_r += C6NUM * col; + } + + if (row_end == row) { + for (; ri < row_end; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + int totalRow = UP_ROUND(row, C6NUM); + for (; ri < totalRow; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row4 = row_end / C4NUM * C4NUM; + int ri = row_start / C4NUM * C4NUM; + int col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row4; ri += C4NUM) { + int ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + +#ifdef ENABLE_ARM32 + int stride = col * 4; + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + + "vtrn.32 d0, d2\n" + "vtrn.32 d1, d3\n" + "vtrn.32 d4, d6\n" + "vtrn.32 d5, d7\n" + + "vswp d1, d4\n" + "vswp d3, d6\n" + + "vst1.32 {q0}, [r12]!\n" + "vst1.32 {q1}, [r12]!\n" + "vst1.32 {q2}, [r12]!\n" + "vst1.32 {q3}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3"); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); + __m128 src12H = _mm_unpackhi_ps(src1, src2); + __m128 src34L = _mm_unpacklo_ps(src3, src4); + __m128 src34H = _mm_unpackhi_ps(src3, src4); + + __m128 dst0 = _mm_movelh_ps(src12L, src34L); + __m128 dst1 = _mm_movehl_ps(src34L, src12L); + __m128 dst2 = _mm_movelh_ps(src12H, src34H); + __m128 dst3 = _mm_movehl_ps(src34H, src12H); 
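+      // dst0..dst3 now hold the four columns of the 4x4 tile: unpacklo/unpackhi interleave
+      // row pairs, then movelh/movehl recombine the 64-bit halves (the usual SSE 4x4
+      // transpose idiom).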
+ + _mm_storeu_ps(dst_c, dst0); + _mm_storeu_ps(dst_c + 4, dst1); + _mm_storeu_ps(dst_c + 8, dst2); + _mm_storeu_ps(dst_c + 12, dst3); +#else + for (size_t tr = 0; tr < C4NUM; tr++) { + for (size_t tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C4NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + int total_row = UP_ROUND(row, C4NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C4NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2ColMajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2RowMajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row4MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row6MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row8MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row12MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row16MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row) { + RowMajor2Row32MajorParallel(src_ptr, dst_ptr, col, row, 0, col); +} +void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row) { + RowMajor2Row64MajorParallel(src_ptr, dst_ptr, col, row, 0, col); +} +void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col12MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col8MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col16MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col32MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col64MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col6MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col4MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_minus = c4 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * 
C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int j = 0; j < c4_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C4NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM; +#ifdef ENABLE_ARM + vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset)); +#else + for (int i = 0; i < C4NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } +#endif + } + int tmp_c = c4_minus * C4NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + if (res_c > channel) { + return; + } + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane); + } +} + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int oc_block = UP_DIV(channel, C4NUM); + int oc_block_channel = oc_block * C4NUM; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int dst_batch_offset = b * oc_block_channel * plane; + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) { + int oc_block = UP_DIV(channel, oc_tile); + int oc_block_channel = oc_block * oc_tile; + int ic_remainder_ = channel % oc_tile; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int dst_batch_offset = b * oc_block_channel * plane; + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void PackNHWCToNXHWCXFp32H1W1(int output_channel, int oc_block_num, int input_channel, float *tmp_weight, + const float *src, int oc_block_unit, Transpose8X8Fp32Func transpose_func) { + int oc_block8 = DOWN_DIV(output_channel, C8NUM); + int oc = 0; + int oc_block = 0; + int ic8 = DOWN_ROUND(input_channel, C8NUM); + int oc_remainder_step = 0; + if (oc_block8 != oc_block_num) { + oc_block8 = oc_block8 / oc_block_unit * oc_block_unit; + oc_remainder_step = (oc_block_num - oc_block8) * C8NUM; + } + for (; oc < oc_block8; oc += (oc_block / C8NUM)) { + oc_block = 
MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8 + for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) { + int ic = 0; + for (; ic < ic8; ic += C8NUM) { + transpose_func(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block); + } + for (; ic < input_channel; ++ic) { + for (int j = 0; j < C8NUM; ++j) { + tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j]; + } + } + src += C8NUM * input_channel; + } + tmp_weight += oc_block * input_channel; + } + oc = output_channel - oc_block8 * C8NUM; + for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) { + for (int ic = 0; ic < input_channel; ++ic) { + tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel]; + } + } +} + +// PackNHWCToNXHWCXFp32 is SWPackNHWCToNXHWCXFp32 asm optimize +void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Arm64; + int oc_block_unit = C2NUM; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Avx; + int oc_block_unit = C4NUM; +#endif + // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8 + // output_channel: batch + int plane = kernel_w * kernel_h; + if (plane == 1) { // conv 1x1 weight pack + PackNHWCToNXHWCXFp32H1W1(output_channel, oc_block_num, input_channel, tmp_weight, src, oc_block_unit, + transpose_func); + return; + } + + int ic8 = DOWN_ROUND(input_channel, C8NUM); + int oc_block8 = DOWN_DIV(output_channel, C8NUM); + int oc_block = 0; + int oc = 0; + int oc_remainder_step = 0; + if (oc_block8 != oc_block_num) { + oc_block8 = oc_block8 / oc_block_unit * oc_block_unit; + oc_remainder_step = (oc_block_num - oc_block8) * C8NUM; + } + for (; oc < oc_block8; oc += (oc_block / C8NUM)) { + oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8 + for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) { + for (int hw = 0; hw < plane; ++hw) { + int ic = 0; + for (; ic < ic8; ic += C8NUM) { + transpose_func(src + hw * input_channel + ic, + tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp, input_channel * plane, + oc_block); + } + for (; ic < input_channel; ++ic) { + for (int j = 0; j < C8NUM; ++j) { + tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] = + src[ic + input_channel * j * plane + hw * input_channel]; + } + } + } + src += C8NUM * plane * input_channel; + } + tmp_weight += oc_block * input_channel * plane; + } + oc = output_channel - oc_block8 * C8NUM; + for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) { + for (int hw = 0; hw < plane; ++hw) { + for (int ic = 0; ic < input_channel; ++ic) { + tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] = + src[ic + (oc_remainder * plane + hw) * input_channel]; + } + } + } +} + +#ifdef ENABLE_DEBUG +void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src) { + // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8 + int oc_block = 0; + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C4NUM, oc_block_num - i); // max_tile = 4 + int index = i * C8NUM * kernel_h * kernel_w * input_channel; + int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM); + for (int h = 0; h < kernel_h; ++h) { + for (int w = 
0; w < kernel_w; ++w) {
+        int w_index = (h * kernel_w + w) * input_channel + index;
+        for (int ic = 0; ic < input_channel; ++ic) {
+          int ic_index = ic + w_index;
+          for (int oc = 0; oc < oc_remainder; ++oc) {
+            int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index;
+            tmp_weight[oc] = src[oc_index];
+          }
+          tmp_weight += oc_block * C8NUM;
+        }
+      }
+    }
+  }
+}
+#endif
+#endif
+
+void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
+  int c8 = UP_DIV(channel, C8NUM);
+  int c8_channel = c8 * C8NUM;
+  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
+  int ic_remainder_ = channel % C8NUM;
+  if (ic_remainder_ != 0) {
+    int nhwc8_batch_offset = 0;
+    for (int b = 0; b < batch; b++) {
+      int batch_offset = b * channel * plane;
+      for (int i = 0; i < plane; i++) {
+        float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
+        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
+        for (int j = channel; j < c8_channel; ++j) {
+          dst_per_plane[j] = 0;
+        }
+      }
+      nhwc8_batch_offset += nhwc8_batch_unit_offset;
+    }
+  } else {
+    size_t ori_input_size = batch * plane * channel * sizeof(float);
+    memcpy((float *)dst, (float *)src, ori_input_size);
+  }
+}
+
+void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) {
+  int c_align = UP_DIV(channel, cx_num);
+  int ic_remainder_ = channel % cx_num;
+  if (ic_remainder_ != 0) {
+    int nhwc_batch_unit_offset = channel * plane;
+    for (int b = 0; b < batch; b++) {
+      int batch_offset = b * c_align * cx_num * plane;
+      for (int i = 0; i < plane; i++) {
+        memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel,
+               (float *)src + batch_offset + i * c_align * cx_num, channel * sizeof(float));
+      }
+    }
+  } else {
+    size_t ori_input_size = batch * plane * channel * sizeof(float);
+    memcpy((float *)dst, (float *)src, ori_input_size);
+  }
+}
+
+void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
+  int c4 = UP_DIV(channel, C4NUM);
+  for (int b = 0; b < batch; b++) {
+    int src_offset = b * plane * c4 * C4NUM;
+    int dst_offset = b * plane * channel;
+    for (int c = 0; c < channel; c++) {
+      int c4_block_num = c / C4NUM;
+      int c4_block_res = c % C4NUM;
+      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
+      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
+      for (int k = 0; k < plane; k++) {
+        int src_kernel_offset = src_c_offset + k * C4NUM;
+        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
+        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
+      }
+    }
+  }
+}
+
+void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel) {
+  const float *fp32_src = (const float *)src;
+  float *fp32_dst = (float *)dst;
+  for (size_t c = 0; c < channel; c++) {
+    size_t c_div = c / C4NUM;
+    size_t c_mod = c % C4NUM;
+    for (size_t p = 0; p < plane; p++) {
+      int src_offset = c_div * plane * C4NUM + p * C4NUM + c_mod;
+      int dst_offset = c * plane + p;
+      fp32_dst[dst_offset] = fp32_src[src_offset];
+    }
+  }
+}
+
+void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
+  int c4 = UP_ROUND(channel, C4NUM);
+  for (int b = 0; b < batch; b++) {
+    int src_offset = b * plane * c4;
+    int dst_offset = b * plane * channel;
+    UnPackC4Uint((const float *)src + src_offset, (float *)dst + dst_offset, plane, channel);
+  }
+}
+
+void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
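+  // Unpacks NC4HW4 back to NHWC: per plane position, the first c4 - 1 full 4-channel
+  // blocks are copied as vectors where available; the final block, which may hold fewer
+  // than four valid channels, is handled element-wise in the "res part" below.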
+ int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset)); +#else + ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; + ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; + ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; + ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; +#endif + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_ROUND(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8; + int dst_offset = b * plane * channel; + + const float *fp32_src = (const float *)src + src_offset; + float *fp32_dst = (float *)dst + dst_offset; + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C8NUM; + size_t c_mod = c % C8NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset_c = c_div * plane * C8NUM + p * C8NUM + c_mod; + int dst_offset_c = c * plane + p; + fp32_dst[dst_offset_c] = fp32_src[src_offset_c]; + } + } + } +} + +void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int c8_minus = c8 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c8 * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C8NUM; + for (int j = 0; j < c8_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C8NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C8NUM; + for (int i = 0; i < C8NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } + } + int tmp_c = c8_minus * C8NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + if (res_c > channel) { + return; + } + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8 * C8NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C8NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c8 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C8NUM; + int dst_c_offset = dst_kernel_offset + c * C8NUM; + + ((float *)dst + dst_c_offset)[Index0] = ((float *)src + src_c_offset)[Index0]; + ((float 
*)dst + dst_c_offset)[Index1] = ((float *)src + src_c_offset)[Index1];
+        ((float *)dst + dst_c_offset)[Index2] = ((float *)src + src_c_offset)[Index2];
+        ((float *)dst + dst_c_offset)[Index3] = ((float *)src + src_c_offset)[Index3];
+      }
+      // res part
+      int res_c = channel - (c8 - 1) * C8NUM;
+      for (int i = 0; i < res_c; i++) {
+        int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i;
+        int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i;
+        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
+      }
+    }
+  }
+}
+
+void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane,
+                                             const int channel) {
+  int down_channel_8 = DOWN_ROUND(channel, C8NUM);
+  int up_channel_16 = UP_ROUND(channel, C16NUM);
+  size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float);
+  size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float);
+  size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float);
+  size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float);
+  size_t src_p_offset = C8NUM * sizeof(float);
+  for (size_t b = 0; b < (size_t)(batch); ++b) {
+    const char *src_batch = (char *)(src) + b * src_batch_offset;
+    char *dst_batch = (char *)(dst) + b * dst_batch_offset;
+    memcpy(dst_batch, src_batch, aligned_channel_size);
+    src_batch += aligned_channel_size;
+    dst_batch += aligned_channel_size;
+    for (int p = 0; p < plane; ++p) {
+      memcpy(dst_batch + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size);
+    }
+  }
+}
+
+void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
+  int channel_up8 = UP_ROUND(channel, C8NUM);
+  for (int n = 0; n < batch; n++) {
+    for (int hw = 0; hw < plane; hw++) {
+      int c = 0;
+      for (; c < channel; c++) {
+        int c8div = c / C8NUM;
+        int c8mod = c % C8NUM;
+        int src_index = n * plane * channel + hw * channel + c;
+        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
+        ((float *)dst)[dst_index] = ((float *)src)[src_index];
+      }
+      for (; c < channel_up8; c++) {
+        int c8div = c / C8NUM;
+        int c8mod = c % C8NUM;
+        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
+        ((float *)dst)[dst_index] = 0;
+      }
+    }
+  }
+}
+
+void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
+  // pack weight NHWC to C24HWN24 (preferred); fall back to C16HWN16, then C8HWN8, when 24 (then 16) channels are not available.
+#ifdef ENABLE_AVX
+  int oc_block_num = UP_DIV(channel, C8NUM);
+  int plane16 = plane / C16NUM * C16NUM;
+  for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
+    oc_block = MSMIN(C3NUM, oc_block_num - i);
+    int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
+    int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
+    int p = 0;
+    for (; p < plane16; p += C16NUM) {
+      int index_plane = i * C8NUM + p * channel;
+      for (int b = 0; b < batch; ++b) {
+        int index_batch = index_plane + b * plane * channel;
+        int oc = 0;
+        int stride = oc_block * C8NUM * batch;
+        for (; oc < oc_remainder_c8; oc += C8NUM) {
+          const float *cur_src = src + index_batch + oc;
+          float *cur_dst = dst + oc;
+          MS_LOAD256X16_F32(r, cur_src, channel);
+          STORE256X16_F32(cur_dst, stride, r);
+        }
+        for (; oc < oc_remainder; ++oc) {
+          for (int k = 0; k < C16NUM; ++k) {
+            dst[oc + stride * k] = src[index_batch + oc + channel * k];
+          }
+        }
+        for (; oc < C8NUM; ++oc) {
+          for (int k = 0; k < C16NUM; ++k) {
+            dst[oc
+ stride * k] = 0; + } + } + dst += oc_block * C8NUM; + } + dst += (C16NUM - 1) * oc_block * C8NUM * batch; + } + for (; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + int oc = 0; + for (; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + for (; oc < C8NUM; ++oc) { + dst[oc] = 0; + } + dst += oc_block * C8NUM; + } + } + } +#else + int oc_block = 0; + int oc_block_num = UP_DIV(channel, C8NUM); + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C3NUM, oc_block_num - i); // max_tile = 4 + int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM); + for (int p = 0; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + for (int oc = 0; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + dst += oc_block * C8NUM; + } + } + } +#endif +} + +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int c = 0; c < c4; c++) { + int dst_off_c = c * C4NUM * height * width; + for (int i = 0; i < C4NUM; i++) { + int src_off_c = (c * C4NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int c = 0; c < c8; c++) { + int dst_off_c = c * C8NUM * height * width; + for (int i = 0; i < C8NUM; i++) { + int src_off_c = (c * C8NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id, + int thread_count) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64; +#elif defined(ENABLE_ARM32) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx; +#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse; +#endif + int hw8 = plane / C8NUM; + int task_start = 0; + int task_end = plane; + if (thread_count > 0) { + int offset_hw = UP_DIV(hw8, thread_count) * C8NUM; + task_start = offset_hw * task_id; + int count = plane - task_start; + if (count <= 0) { + return; + } + task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); + hw8 = task_start + ((task_end - task_start) >= offset_hw ? 
offset_hw : 0);
+  } else {
+    hw8 *= C8NUM;
+  }
+  int c8 = channel / C8NUM * C8NUM;
+  int batch = plane * channel;
+  for (int n = 0; n < batches; n++) {
+    const float *src_batch = (const float *)src + n * batch;
+    float *dst_batch = (float *)dst + n * batch;
+    int hw = task_start;
+    for (; hw < hw8; hw += C8NUM) {
+      int c = 0;
+      for (; c < c8; c += C8NUM) {
+        const float *src_ptr = src_batch + hw * channel + c;
+        float *dst_ptr = dst_batch + c * plane + hw;
+#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
+        Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane);
+#else
+        for (int tr = 0; tr < C8NUM; tr++) {
+          for (int tc = 0; tc < C8NUM; tc++) {
+            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
+          }
+        }
+#endif
+      }
+      for (; c < channel; c++) {
+        const float *src_ptr = src_batch + hw * channel + c;
+        float *dst_ptr = dst_batch + c * plane + hw;
+        for (size_t i = 0; i < C8NUM; i++) {
+          dst_ptr[i] = src_ptr[i * channel];
+        }
+      }
+    }
+    for (; hw < task_end; hw++) {
+      const float *src_ptr = src_batch + hw * channel;
+      float *dst_ptr = dst_batch + hw;
+      for (size_t i = 0; i < channel; i++) {
+        dst_ptr[i * plane] = src_ptr[i];
+      }
+    }
+  }
+}
+
+/*
+|<---------------- plane --------------->|
++---------------------------+------------+ ---
+|          |     |          |            |  ↑
+|8x8-blocks| ... |8x8-blocks|   right    |  |
+|          |     |          |            |  |
++ - - - - -+     + - - - - -+            |  |
+|   ...    ...      ...     |    top     | channel
++ - - - - -+     + - - - - -|            |  |
+|          |     |          |   tails    |  |
+|8x8-blocks| ... |8x8-blocks|            |  |
++---------------------------+------------+  |
+|                           |right bottom|  |
+|     left bottom tails     |   tails    |  ↓
++---------------------------+------------+ ---
+*/
+void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end) {
+#ifdef ENABLE_ARM64
+  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
+#elif defined(ENABLE_ARM32)
+  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
+#elif defined(ENABLE_AVX)
+  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
+#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
+  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
+#endif
+  int m_pad = UP_DIV(channel, C8NUM);
+  int n_pad = UP_DIV(plane, C8NUM);
+  int m_blk = channel / C8NUM;
+  int n_blk = plane / C8NUM;
+  int b_stride = plane * channel;
+  int b = 0, m = 0, n = 0;
+  // To keep dst writes consecutive, iterate (m,n):(0,0)->(1,0)->...->(0,1)->(1,1)->...
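+  // i.e. m (the channel-block index) varies fastest, so successive tasks write to
+  // consecutive channel offsets of the same group of output rows, keeping dst stores
+  // close to sequential.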
+ offset_to_index_init(start, 6, &m, m_pad, &n, n_pad, &b, batches); + for (int task = start; task < end; task++) { + const float *src_batch = (const float *)src + b * b_stride; + float *dst_batch = (float *)dst + b * b_stride; + int m_start = m * C8NUM; + int n_start = n * C8NUM; + if (m < m_blk && n < n_blk) { + // process 8x8-blocks + const float *from = src_batch + m_start * plane + n_start; + float *to = dst_batch + n_start * channel + m_start; +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32) + Transpose8X8Fp32Func_(from, to, plane, channel); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + to[tc * channel + tr] = from[tr * plane + tc]; + } + } +#endif + } else { + // process right bottom tails + const float *from = src_batch; + float *to = dst_batch; + int i_start = m_start; + int i_end = channel; + int j_start = n_start; + int j_end = plane; + if (m >= m_blk && n < n_blk) { + // process left bottom tails + from = src_batch + n_start; + to = dst_batch + n_start * channel; + j_start = 0; + j_end = C8NUM; + } else if (m < m_blk && n >= n_blk) { + // process right top tails + from = src_batch + m_start * plane; + to = dst_batch + m_start; + i_start = 0; + i_end = C8NUM; + } + transpose_tail(from, to, j_start, j_end, i_start, i_end, channel, plane); + } + offset_to_index_step(6, &m, m_pad, &n, n_pad, &b, batches); + } +} + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) { + PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count); +} + +#ifdef ENABLE_ARM64 +inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + size_t srcStride = src_stride * sizeof(float); + size_t dstStride = dst_stride * sizeof(float); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "trn1 v16.2d, v8.2d, v10.2d\n" + "trn2 v18.2d, v8.2d, v10.2d\n" + "trn1 v20.2d, v9.2d, v11.2d\n" + "trn2 v22.2d, v9.2d, v11.2d\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "trn1 v24.2d, v12.2d, v14.2d\n" + "trn2 v26.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v30.2d, v13.2d, v15.2d\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v17.2d, v8.2d, v10.2d\n" + "trn2 v19.2d, v8.2d, v10.2d\n" + "trn1 v21.2d, v9.2d, v11.2d\n" + "trn2 v23.2d, v9.2d, v11.2d\n" + + "trn1 v25.2d, v12.2d, v14.2d\n" + "trn2 v27.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n" + "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n" + "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n" + "st1 {v22.4s, v23.4s}, [x11], 
%[dstStride]\n" + "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n" + "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n" + "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n" + "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n" + + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +#ifdef ENABLE_ARM32 +inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + size_t srcStride = src_stride * sizeof(float); + size_t dstStride = dst_stride * sizeof(float); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.32 {q0, q1}, [r10], %[srcStride]\n" + "vld1.32 {q2, q3}, [r10], %[srcStride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vld1.32 {q4, q5}, [r10], %[srcStride]\n" + "vld1.32 {q6, q7}, [r10], %[srcStride]\n" + + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vld1.32 {q8, q9}, [r10], %[srcStride]\n" + "vld1.32 {q10, q11}, [r10], %[srcStride]\n" + + "vswp d1, d8\n" + "vswp d3, d10\n" + "vswp d5, d12\n" + "vswp d7, d14\n" + + "vtrn.32 d16, d20\n" + "vtrn.32 d17, d21\n" + "vtrn.32 d18, d22\n" + "vtrn.32 d19, d23\n" + + "vld1.32 {q12, q13}, [r10], %[srcStride]\n" + "vld1.32 {q14, q15}, [r10], %[srcStride]\n" + + "vtrn.32 d24, d28\n" + "vtrn.32 d25, d29\n" + "vtrn.32 d26, d30\n" + "vtrn.32 d27, d31\n" + + "vswp d17, d24\n" + "vswp d19, d26\n" + "vswp d21, d28\n" + "vswp d23, d30\n" + + "add r10, r12, #16\n" + "vst1.32 {q0}, [r12], %[dstStride]\n" + "vst1.32 {q8}, [r10], %[dstStride]\n" + "vst1.32 {q2}, [r12], %[dstStride]\n" + "vst1.32 {q10}, [r10], %[dstStride]\n" + "vst1.32 {q4}, [r12], %[dstStride]\n" + "vst1.32 {q12}, [r10], %[dstStride]\n" + "vst1.32 {q6}, [r12], %[dstStride]\n" + "vst1.32 {q14}, [r10], %[dstStride]\n" + "vst1.32 {q1}, [r12], %[dstStride]\n" + "vst1.32 {q9}, [r10], %[dstStride]\n" + "vst1.32 {q3}, [r12], %[dstStride]\n" + "vst1.32 {q11}, [r10], %[dstStride]\n" + "vst1.32 {q5}, [r12], %[dstStride]\n" + "vst1.32 {q13}, [r10], %[dstStride]\n" + "vst1.32 {q7}, [r12], %[dstStride]\n" + "vst1.32 {q15}, [r10], %[dstStride]\n" + + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_AVX +/* + Using _mm256_insertf128_ps at the beginning, instead of using _mm256_permute2f128_ps at the end. + On the whole, 4 vinsertf128 and 4 vperm2f128 are used less than before. 
+*/ +inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + const float *src1 = src_ptr + 0 * src_stride; + const float *src2 = src_ptr + 1 * src_stride; + const float *src3 = src_ptr + 2 * src_stride; + const float *src4 = src_ptr + 3 * src_stride; + const float *src5 = src_ptr + 4 * src_stride; + const float *src6 = src_ptr + 5 * src_stride; + const float *src7 = src_ptr + 6 * src_stride; + const float *src8 = src_ptr + 7 * src_stride; + + __m256 r1, r2, r3, r4, r5, r6, r7, r8; + __m256 t1, t2, t3, t4, t5, t6, t7, t8; + // _mm256_castps128_ps256 is only for compilation and generates no instructions, thus it has zero latency. + r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 0)), _mm_loadu_ps(src5 + 0), 1); + r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 0)), _mm_loadu_ps(src6 + 0), 1); + r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 0)), _mm_loadu_ps(src7 + 0), 1); + r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 0)), _mm_loadu_ps(src8 + 0), 1); + r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 4)), _mm_loadu_ps(src5 + 4), 1); + r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 4)), _mm_loadu_ps(src6 + 4), 1); + r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 4)), _mm_loadu_ps(src7 + 4), 1); + r8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 4)), _mm_loadu_ps(src8 + 4), 1); + + t1 = _mm256_unpacklo_ps(r1, r2); + t2 = _mm256_unpackhi_ps(r1, r2); + t3 = _mm256_unpacklo_ps(r3, r4); + t4 = _mm256_unpackhi_ps(r3, r4); + t5 = _mm256_unpacklo_ps(r5, r6); + t6 = _mm256_unpackhi_ps(r5, r6); + t7 = _mm256_unpacklo_ps(r7, r8); + t8 = _mm256_unpackhi_ps(r7, r8); + + __m256 v; + v = _mm256_shuffle_ps(t1, t3, 0x4E); + r1 = _mm256_blend_ps(t1, v, 0xCC); + r2 = _mm256_blend_ps(t3, v, 0x33); + + v = _mm256_shuffle_ps(t2, t4, 0x4E); + r3 = _mm256_blend_ps(t2, v, 0xCC); + r4 = _mm256_blend_ps(t4, v, 0x33); + + v = _mm256_shuffle_ps(t5, t7, 0x4E); + r5 = _mm256_blend_ps(t5, v, 0xCC); + r6 = _mm256_blend_ps(t7, v, 0x33); + + v = _mm256_shuffle_ps(t6, t8, 0x4E); + r7 = _mm256_blend_ps(t6, v, 0xCC); + r8 = _mm256_blend_ps(t8, v, 0x33); + + STORE256X8_F32(dst_ptr, dst_stride, r); +} +#endif + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + __m128 v0_ma = _mm_loadu_ps(src_ptr); + __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride); + __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride); + __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride); + + __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr, v8_ma); + _mm_storeu_ps(dst_ptr + dst_stride, v9_ma); + _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma); + _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM); + v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM); + v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM); + v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + 
v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride); + v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride); + v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride); + v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM, v8_ma); + _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma); + _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma); + _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM); + v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM); + v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM); + v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma); +} +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) { + // nchw to nc4hw4 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float *src_kernel = (float *)src + i * 9; + float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4; + for (int y = 0; y < 3; y++) { + float g0 = src_kernel[3 * y]; + float g1 = src_kernel[3 * y + 1]; + float g2 = src_kernel[3 * y + 2]; + + dst_kernel[16 * y] = g0; + dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2); + dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2); + dst_kernel[16 * y + 12] = g2; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..4558e93485c1075ccb10519549ba17ed23343aa7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PACK_H_ +#define MINDSPORE_NNACL_FP32_PACK_H_ + +#include +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +static inline void transpose_tail(const float *from, float *to, int j_start, int j_end, int i_start, int i_end, + int j_stride, int i_stride) { + // write consecutively + for (int j = j_start; j < j_end; j++) { + for (int i = i_start; i < i_end; i++) { + to[j * j_stride + i] = from[i * i_stride + j]; + } + } +} +void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end); +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel); +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile); +void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel); +// Note: If not multithreaded, please set task_id = 0 and thread_count = 0; +void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); +void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num); +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); +void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel); +void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel); +void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel); +void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel); + +void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void 
RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end); +void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end); +void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); + +void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row); +void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row); +void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col); + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel); + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel); +#endif + +// Transpose 8X8 Fp32 block data +typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#ifdef ENABLE_ARM64 +void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, 
int dst_stride); +#endif +#ifdef ENABLE_ARM32 +void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src); +#endif +#ifdef ENABLE_AVX +#ifdef ENABLE_DEBUG +void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src); +#endif +void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PACK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c new file mode 100644 index 0000000000000000000000000000000000000000..a58fd3c411a870d89f5d8159759ec4c5e5787cff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c @@ -0,0 +1,292 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/pack_fp32_opt.h" +#include "nnacl_c/op_base.h" + +void RowMajor2Col12MajorOptCore(const float *src_c, float *dst_c, size_t stride, int64_t row, int64_t col) { + if (row <= 0 || col <= 0) { + return; + } + size_t stride_byte = stride * sizeof(float); + size_t stride_unit = stride * (C12NUM - 1); + int64_t r = 0; + for (; r <= row - C12NUM; r += C12NUM) { + int64_t c = 0; + for (; c <= col - C4NUM; c += C4NUM) { + asm volatile( + "mov x9, %[src_c]\n" + "mov x10, %[dst_c]\n" + + "ld1 {v0.4s}, [x9], %[stride_byte]\n" + "ld1 {v1.4s}, [x9], %[stride_byte]\n" + "ld1 {v2.4s}, [x9], %[stride_byte]\n" + "ld1 {v3.4s}, [x9], %[stride_byte]\n" + + "ld1 {v4.4s}, [x9], %[stride_byte]\n" + "ld1 {v5.4s}, [x9], %[stride_byte]\n" + "ld1 {v6.4s}, [x9], %[stride_byte]\n" + "ld1 {v7.4s}, [x9], %[stride_byte]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x9], %[stride_byte]\n" + "ld1 {v9.4s}, [x9], %[stride_byte]\n" + "ld1 {v10.4s}, [x9], %[stride_byte]\n" + "ld1 {v11.4s}, [x9], %[stride_byte]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride_byte] "r"(stride_byte) + : "memory", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + dst_c += C48NUM; + src_c += C4NUM; + } + for (; c < col; ++c) { + for (int i = 0; i < C12NUM; ++i) { + dst_c[i] = src_c[i * stride]; + } + ++src_c; + dst_c += C12NUM; + } + src_c += stride_unit; + } + for (; r < row; ++r) { + for (int c = 0; c < col; ++c) { + dst_c[c * C12NUM] = src_c[c]; + } + src_c += stride; + ++dst_c; + } +} + +void RowMajor2Row12MajorOptCore(const float *src_c, float *dst_c, size_t stride, int64_t row, int64_t col) { + if (row <= 0 || col <= 0) { + return; + } + size_t stride_byte = stride * sizeof(float); + int64_t c = 0; + for (; c <= col - C12NUM; c += C12NUM) { + asm volatile( + "mov x9, %[src_c]\n" + "mov x10, %[dst_c]\n" + "mov x11, %[row]\n" + "1:\n" + "ld1 {v0.4s, v1.4s, v2.4s}, [x9], %[stride_byte]\n" + "st1 {v0.4s, v1.4s, v2.4s}, [x10], #48\n" + "subs x11, x11, #1\n" + "bgt 1b\n" + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride_byte] "r"(stride_byte), [row] "r"(row) + : "cc", "memory", "x9", "x10", "x11", "v0", "v1", "v2"); + dst_c += row * C12NUM; + src_c += C12NUM; + } + int64_t c_remain = col - c; + if (c_remain == 0) { + return; + } + for (int64_t r = 0; r < row; ++r) { + for (c = 0; c < c_remain; ++c) { + dst_c[r * C12NUM + c] = src_c[c]; + } + src_c += stride; + } +} + +void 
RowMajor2Col12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, + int64_t end) { + int64_t bundle_row = UP_DIV(row, C12NUM); + int64_t unit_num_per_batch = bundle_row * col; + if (unit_num_per_batch == 0) { + return; + } + int64_t start_batch = start / unit_num_per_batch; + int64_t end_batch = end / unit_num_per_batch; + int64_t start_remain = start % unit_num_per_batch; + int64_t end_remain = end % unit_num_per_batch; + if (col == 0) { + return; + } + int64_t start_row = start_remain / col; + int64_t end_row = end_remain / col; + int64_t start_col = start_remain % col; + int64_t end_col = end_remain % col; + const float *src = src_ptr + start_batch * row * col + start_row * C12NUM * col + start_col; + float *dst = dst_ptr + start * C12NUM; + int64_t row_num = C12NUM; + if (start_row * C12NUM + C12NUM > row) { + row_num -= (start_row * C12NUM + C12NUM - row); + } + if (start_batch == end_batch) { + if (start_row == end_row) { + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col - start_col); + return; + } + RowMajor2Col12MajorOptCore(src, dst, col, C12NUM, col - start_col); + src += C12NUM * col - start_col; + dst += (col - start_col) * C12NUM; + ++start_row; + if (start_row < end_row) { + row_num = (end_row - start_row) * C12NUM; + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += row_num * col; + } + row_num = C12NUM; + if (end_row * C12NUM + C12NUM > row) { + row_num -= (end_row * C12NUM + C12NUM - row); + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col); + return; + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col - start_col); + src += row_num * col - start_col; + dst += (col - start_col) * C12NUM; + row_num = row - start_row * C12NUM - C12NUM; + if (row_num > 0) { + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += UP_DIV(row_num, C12NUM) * C12NUM * col; + } + ++start_batch; + for (; start_batch < end_batch; ++start_batch) { + RowMajor2Col12MajorOptCore(src, dst, col, row, col); + src += row * col; + dst += bundle_row * C12NUM * col; + } + if (end_row > 0) { + row_num = end_row * C12NUM; + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += row_num * col; + } + row_num = C12NUM; + if (end_row * C12NUM + C12NUM > row) { + row_num -= (end_row * C12NUM + C12NUM - row); + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col); +} + +void RowMajor2Row12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, + int64_t end) { + int64_t bundle_col = UP_DIV(col, C12NUM); + int64_t unit_num_per_batch = bundle_col * row; + if (unit_num_per_batch == 0) { + return; + } + int64_t start_batch = start / unit_num_per_batch; + int64_t end_batch = end / unit_num_per_batch; + int64_t start_remain = start % unit_num_per_batch; + int64_t end_remain = end % unit_num_per_batch; + if (row == 0) { + return; + } + int64_t start_row = start_remain % row; + int64_t end_row = end_remain % row; + int64_t start_col = start_remain / row; + int64_t end_col = end_remain / row; + const float *src = src_ptr + start_batch * row * col + start_row * col + start_col * C12NUM; + float *dst = dst_ptr + start * C12NUM; + int64_t col_num = C12NUM; + if (start_col * C12NUM + C12NUM > col) { + col_num -= (start_col * C12NUM + C12NUM - col); + } + if (start_batch == end_batch) { + if (start_col == end_col) { + RowMajor2Row12MajorOptCore(src, dst, col, end_row - start_row, col_num); + return; + } + 
RowMajor2Row12MajorOptCore(src, dst, col, row - start_row, col_num); + src += C12NUM - start_row * col; + dst += (row - start_row) * C12NUM; + ++start_col; + if (start_col < end_col) { + col_num = (end_col - start_col) * C12NUM; + RowMajor2Row12MajorOptCore(src, dst, col, row, col_num); + src += col_num; + dst += row * col_num; + } + col_num = C12NUM; + if (end_col * C12NUM + C12NUM > col) { + col_num -= (end_col * C12NUM + C12NUM - col); + } + RowMajor2Row12MajorOptCore(src, dst, col, end_row, col_num); + return; + } + RowMajor2Row12MajorOptCore(src, dst, col, row - start_row, col_num); + src += col_num - start_row * col; + dst += (row - start_row) * C12NUM; + col_num = col - start_col * C12NUM - C12NUM; + if (col_num > 0) { + RowMajor2Row12MajorOptCore(src, dst, col, row, col_num); + src += col_num; + dst += UP_DIV(col_num, C12NUM) * C12NUM * row; + } + src += (row - 1) * col; + ++start_batch; + for (; start_batch < end_batch; ++start_batch) { + RowMajor2Row12MajorOptCore(src, dst, col, row, col); + src += row * col; + dst += bundle_col * C12NUM * row; + } + if (end_col > 0) { + col_num = end_col * C12NUM; + RowMajor2Row12MajorOptCore(src, dst, col, row, col_num); + src += col_num; + dst += row * col_num; + } + col_num = C12NUM; + if (end_col * C12NUM + C12NUM > col) { + col_num -= (end_col * C12NUM + C12NUM - col); + } + RowMajor2Row12MajorOptCore(src, dst, col, end_row, col_num); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h new file mode 100644 index 0000000000000000000000000000000000000000..95a039cbcde75d32034545bafc1fd05e7c1aae04 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PACK_FP32_V2_H +#define MINDSPORE_NNACL_FP32_PACK_FP32_V2_H + +#ifdef ENABLE_ARM64 +#include <arm_neon.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Packing routines that support fine-grained multi-threading: each call handles the [start, end) slice of the work. + */ + +void RowMajor2Col12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, int64_t end); + +void RowMajor2Row12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, int64_t end); + +#ifdef __cplusplus +} +#endif +#endif +#endif // MINDSPORE_NNACL_FP32_PACK_FP32_V2_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5115e4d16ef1ac39d62eeae344ebbb7d8515dab0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/pad_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +void Pad(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *paddings, int tid, int thread_num) { + if (thread_num == 0) { + return; + } + int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + for (in[4] = 0; in[4] < input_shape[4]; in[4]++) { + out[4] = in[4] + paddings[8]; + float *dst = output_data + Offset6d(output_shape, out) + paddings[10]; + const float *src = input_data + Offset6d(input_shape, in); + memcpy(dst, src, input_shape[5] * (int)(sizeof(float))); + } + } + } + } + } +} + +int TransOut2InputDimIndex(int out_dim_index, int left_pad, int in_dim, int offset) { + if (out_dim_index < left_pad) { + // left pad + const int index_sum = left_pad + offset - 1; + int in_index = MSMAX(index_sum - out_dim_index, offset); + return MSMIN(in_index, in_dim - 1); + } + out_dim_index -= left_pad; + if (out_dim_index < in_dim) { + return out_dim_index; + } + // right pad + out_dim_index -= in_dim; + const int index_sum = in_dim - 1 - offset; + return MSMAX(index_sum - out_dim_index, 0); +} + +int GetInputFlattenIndex(int out_flatten_index, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset) { + int in_flatten_index = 0; + for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + int left_pad = paddings[i * 2]; + NNACL_CHECK_ZERO_RETURN_ERR(out_strides[i]); + int out_dim_index = out_flatten_index / out_strides[i]; + out_flatten_index %= out_strides[i]; + int in_dim_index = TransOut2InputDimIndex(out_dim_index, left_pad, input_shape[i], mirror_offset); + in_flatten_index += in_dim_index * in_strides[i]; + } + return in_flatten_index; +} + +void MirrorPad(const float *input_data, float *output_data, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset, int begin, int end) { + for (int i = begin; i < end; ++i) { + output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, in_strides, out_strides, paddings, mirror_offset)]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..762000f2033e223f3ed4942132ce01df285c6d7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_PAD_FP32_H_ +#define NNACL_FP32_PAD_FP32_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/pad_parameter.h" + +int GetInputFlattenIndex(int out_flatten_index, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset); +#ifdef __cplusplus +extern "C" { +#endif +void Pad(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *paddings, int tid, int thread_num); +void MirrorPad(const float *input_data, float *output_data, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_PAD_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a3653b94b7ed16051db7e784c873f2ad1437d5a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c @@ -0,0 +1,786 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/pooling_fp32.h" +#include <float.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/pooling_fp32_simd.h" + +int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float *src_plane_ptr = src_b_ptr; + float *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + int ci = 0; + + NNACL_CHECK_TRUE_RET(real_win_h_end > real_win_h_start, NNACL_ERR); + NNACL_CHECK_TRUE_RET(real_win_w_end > real_win_w_start, NNACL_ERR); + SIMD_RUN_NO_SCALAR(AvgPoolingBatch, ci, src_plane_ptr, channel, dst_plane_ptr, real_win_h_start, real_win_h_end, + real_win_w_start, real_win_w_end, in_h_index, in_w, in_w_index, pooling_args->minf, + pooling_args->maxf); + + for (; ci < channel; ci++) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + float tmp_avg = 0; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += src_win_ptr[0]; + ++real_count; + } // win_w loop + } // win_h loop + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + return NNACL_OK; +} + +int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = AvgPoolingBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + const int c_xtile = once_calc_num * c_tile; + + int cur_c = (channel / c_xtile) * c_xtile; + int 
last_c_size = channel - cur_c; + + int less_out_plane = out_plane * last_c_size; + int calc_tile = UP_DIV(less_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane; + + int c_start = (index_begin / out_plane) + cur_c; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) + cur_c; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + + int in_w_cx_line = in_w * last_c_size; + const float *src_c_ptr = src_b_ptr + cur_c * in_plane; + for (; c < channel; c++) { + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + float tmp_avg = 0.0; + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start); + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c); + tmp_avg += cur_input_index[0]; + } + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; + } + w = 0; + } + h = 0; + } + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; + MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(pooling_args->minf); + MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(pooling_args->maxf); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; + MS_FLOAT32X4 min_value = MS_MOVQ_F32(pooling_args->minf); + MS_FLOAT32X4 max_value = MS_MOVQ_F32(pooling_args->maxf); +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + int in_w_cx_line = in_w * c_tile; + const int c_xtile = once_calc_num * c_tile; + int c_tile_num = channel / c_xtile; + int all_out_plane = out_plane * c_tile_num; + int calc_tile = UP_DIV(all_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < all_out_plane ? 
(index_begin + calc_tile) : all_out_plane; + + int c_start = (index_begin / out_plane) * c_xtile; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) * c_xtile; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + for (; c < channel; c += c_xtile) { + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + +#ifdef ENABLE_AVX + MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0); + MS_FLOAT32X8 tmp_avg2 = MS_MOV256_F32(0); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0); +#else + float tmp_avg = 0; +#endif + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start); + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile; +#ifdef ENABLE_AVX + tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(cur_input_index)); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(cur_input_index)); +#else + tmp_avg += cur_input_index[0]; +#endif + } + +#ifdef ENABLE_AVX + const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile; + + tmp_avg2 = MS_ADD256_F32(tmp_avg2, MS_LD256_F32(cur_input_index)); + } +#endif + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; +#ifdef ENABLE_AVX + float *dst_c2_ptr = dst_c_ptr + c_tile; + + tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count)); + tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8); + tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8); + MS_ST256_F32(dst_c_ptr, tmp_avg); + + tmp_avg2 = MS_DIV256_F32(tmp_avg2, MS_MOV256_F32(real_count)); + tmp_avg2 = MS_MAX256_F32(tmp_avg2, min_value_8); + tmp_avg2 = MS_MIN256_F32(tmp_avg2, max_value_8); + MS_ST256_F32(dst_c2_ptr, tmp_avg2); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count)); + tmp_avg = MS_MAXQ_F32(tmp_avg, min_value); + tmp_avg = MS_MINQ_F32(tmp_avg, max_value); + MS_STQ_F32(dst_c_ptr, tmp_avg); +#else + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; +#endif + } + w = 0; + } + h = 0; + } + + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, 
int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = AvgPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + + ret = AvgPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float *src_plane_ptr = src_b_ptr; + float *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + int ci = 0; + + SIMD_RUN_NO_SCALAR(MaxPoolingBatch, ci, src_plane_ptr, channel, dst_plane_ptr, real_win_h_start, real_win_h_end, + real_win_w_start, real_win_w_end, in_h_index, in_w, in_w_index, pooling_args->minf, + pooling_args->maxf); + + for (; ci < channel; ci++) { + float *dst_c_ptr = dst_plane_ptr + ci; + const float *src_c_ptr = src_plane_ptr + ci; + float tmp_max = -FLT_MAX; + for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; + tmp_max = fmaxf(tmp_max, src_win_ptr[0]); + } // win_w loop + } // win_h loop + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + return NNACL_OK; +} + +int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = 
pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = MaxPoolingBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + const int c_xtile = once_calc_num * c_tile; + + int cur_c = (channel / c_xtile) * c_xtile; + int last_c_size = channel - cur_c; + + int less_out_plane = out_plane * last_c_size; + int calc_tile = UP_DIV(less_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane; + + int c_start = (index_begin / out_plane) + cur_c; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) + cur_c; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + + int in_w_cx_line = in_w * last_c_size; + const float *src_c_ptr = src_b_ptr + cur_c * in_plane; + for (; c < channel; c++) { + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + float tmp_max = -FLT_MAX; + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c); + tmp_max = fmaxf(tmp_max, cur_input_index[0]); + } + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; + } + w = 0; + } + h = 0; + } + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, 
int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; + MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(pooling_args->minf); + MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(pooling_args->maxf); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; + MS_FLOAT32X4 min_value = MS_MOVQ_F32(pooling_args->minf); + MS_FLOAT32X4 max_value = MS_MOVQ_F32(pooling_args->maxf); +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + int in_w_cx_line = in_w * c_tile; + const int c_xtile = once_calc_num * c_tile; + int c_tile_num = channel / c_xtile; + int all_out_plane = out_plane * c_tile_num; + int calc_tile = UP_DIV(all_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < all_out_plane ? (index_begin + calc_tile) : all_out_plane; + + int c_start = (index_begin / out_plane) * c_xtile; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) * c_xtile; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + for (; c < channel; c += c_xtile) { + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + +#ifdef ENABLE_AVX + MS_FLOAT32X8 tmp_max = MS_MOV256_F32(-FLT_MAX); + MS_FLOAT32X8 tmp_max2 = MS_MOV256_F32(-FLT_MAX); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX); +#else + float tmp_max = -FLT_MAX; +#endif + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile; +#ifdef ENABLE_AVX + tmp_max = MS_MAX256_F32(tmp_max, MS_LD256_F32(cur_input_index)); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(cur_input_index)); +#else + tmp_max = fmaxf(tmp_max, cur_input_index[0]); +#endif + } + +#ifdef ENABLE_AVX + const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile; + + tmp_max2 = MS_MAX256_F32(tmp_max2, MS_LD256_F32(cur_input_index)); + } +#endif + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * 
channel + c; +#ifdef ENABLE_AVX + float *dst_c2_ptr = dst_c_ptr + c_tile; + + tmp_max = MS_MAX256_F32(tmp_max, min_value_8); + tmp_max = MS_MIN256_F32(tmp_max, max_value_8); + MS_ST256_F32(dst_c_ptr, tmp_max); + + tmp_max2 = MS_MAX256_F32(tmp_max2, min_value_8); + tmp_max2 = MS_MIN256_F32(tmp_max2, max_value_8); + MS_ST256_F32(dst_c2_ptr, tmp_max2); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_max = MS_MAXQ_F32(tmp_max, min_value); + tmp_max = MS_MINQ_F32(tmp_max, max_value); + MS_STQ_F32(dst_c_ptr, tmp_max); +#else + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; +#endif + } + w = 0; + } + h = 0; + } + + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = MaxPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + + ret = MaxPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end) { + // Access structure members in declaration order + int in_size_w = pooling_args->pooling_compute_param_.input_w_; + int in_size_h = pooling_args->pooling_compute_param_.input_h_; + int batch = pooling_args->pooling_compute_param_.input_batch_; + int channel = pooling_args->pooling_compute_param_.input_channel_; + int out_size_w = pooling_args->pooling_compute_param_.output_w_; + int out_size_h = pooling_args->pooling_compute_param_.output_h_; + int in_size_d = pooling_args->input_d_; + int out_size_d = pooling_args->output_d_; + + int kernel_w = pooling_param->pooling_parameter_.window_w_; + int kernel_h = pooling_param->pooling_parameter_.window_h_; + int stride_w = pooling_param->pooling_parameter_.stride_w_; + int stride_h = pooling_param->pooling_parameter_.stride_h_; + int pad_l_h = pooling_param->pooling_parameter_.pad_u_; + int pad_l_w = pooling_param->pooling_parameter_.pad_l_; + int kernel_d = pooling_param->window_d_; + int stride_d = pooling_param->stride_d_; + int pad_l_d = pooling_param->pad_f_; + + int n_stride = in_size_d * in_size_h * in_size_w * channel; + int d_stride = in_size_h * in_size_w * channel; + int h_stride = in_size_w * channel; + + int n = 0, d = 0, h = 0, w = 0; + const int parallel_dims = 4; // parallel on N/D/H/W four dims + offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, + batch); + + for (int i = start; i < end; i++) { + int d_start = d * stride_d - pad_l_d; + int d_end = MSMIN(d_start + kernel_d, in_size_d); + d_start = MSMAX(d_start, 0); + int h_start = h * stride_h - pad_l_h; + int h_end = MSMIN(h_start + kernel_h, in_size_h); + h_start = 
MSMAX(h_start, 0); + int w_start = w * stride_w - pad_l_w; + int w_end = MSMIN(w_start + kernel_w, in_size_w); + w_start = MSMAX(w_start, 0); + + const float *src_batch_ptr = input_ptr + n * n_stride; + float *out = output_ptr + i * channel; + int c_idx = 0; + SIMD_RUN_NO_SCALAR(MaxPooling3D, c_idx, src_batch_ptr, channel, out, d_start, d_end, h_start, h_end, w_start, w_end, + d_stride, h_stride); + for (; c_idx < channel; ++c_idx) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + float tmp_max = -FLT_MAX; + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_max = MSMAX(input[0], tmp_max); + } + } + } + dst_c_ptr[0] = tmp_max; + } + offset_to_index_step(parallel_dims * 2, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch); + } +} + +void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end) { + // Access structure members in declaration order + int in_size_w = pooling_args->pooling_compute_param_.input_w_; + int in_size_h = pooling_args->pooling_compute_param_.input_h_; + int batch = pooling_args->pooling_compute_param_.input_batch_; + int channel = pooling_args->pooling_compute_param_.input_channel_; + int out_size_w = pooling_args->pooling_compute_param_.output_w_; + int out_size_h = pooling_args->pooling_compute_param_.output_h_; + int in_size_d = pooling_args->input_d_; + int out_size_d = pooling_args->output_d_; + + int kernel_w = pooling_param->pooling_parameter_.window_w_; + int kernel_h = pooling_param->pooling_parameter_.window_h_; + int stride_w = pooling_param->pooling_parameter_.stride_w_; + int stride_h = pooling_param->pooling_parameter_.stride_h_; + int pad_l_h = pooling_param->pooling_parameter_.pad_u_; + int pad_r_h = pooling_param->pooling_parameter_.pad_d_; + int pad_l_w = pooling_param->pooling_parameter_.pad_l_; + int pad_r_w = pooling_param->pooling_parameter_.pad_r_; + int kernel_d = pooling_param->window_d_; + int stride_d = pooling_param->stride_d_; + int pad_l_d = pooling_param->pad_f_; + int pad_r_d = pooling_param->pad_b_; + bool count_include_pad = pooling_param->count_include_pad_; + int divisor = pooling_param->divisor_override_; + + int n_stride = in_size_d * in_size_h * in_size_w * channel; + int d_stride = in_size_h * in_size_w * channel; + int h_stride = in_size_w * channel; + + const int d_max = in_size_d + pad_r_d; + const int h_max = in_size_h + pad_r_h; + const int w_max = in_size_w + pad_r_w; + + int n = 0, d = 0, h = 0, w = 0; + const int parallel_dims = 4; // parallel on N/D/H/W four dims + offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, + batch); + + for (int i = start; i < end; i++) { + int d_start = d * stride_d - pad_l_d; + int d_end = MSMIN(d_start + kernel_d, d_max); + int d_start2 = MSMAX(d_start, 0); + int d_end2 = MSMIN(d_end, in_size_d); + int h_start = h * stride_h - pad_l_h; + int h_end = MSMIN(h_start + kernel_h, h_max); + int h_start2 = MSMAX(h_start, 0); + int h_end2 = MSMIN(h_end, in_size_h); + int w_start = w * stride_w - pad_l_w; + int w_end = MSMIN(w_start + kernel_w, w_max); + int w_start2 = MSMAX(w_start, 0); + int w_end2 = MSMIN(w_end, in_size_w); + + const float *src_batch_ptr = input_ptr + n * n_stride; + float *out = output_ptr + i * 
channel; + + if (pooling_param->divisor_override_ == 0) { + if (count_include_pad) { + divisor = (d_end - d_start) * (h_end - h_start) * (w_end - w_start); + } else { + divisor = (d_end2 - d_start2) * (h_end2 - h_start2) * (w_end2 - w_start2); + } + } + + int c_idx = 0; + SIMD_RUN_NO_SCALAR(AvgPooling3D, c_idx, src_batch_ptr, channel, out, d_start2, d_end2, h_start2, h_end2, w_start2, + w_end2, d_stride, h_stride, divisor); + for (; c_idx < channel; ++c_idx) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + float tmp_avg = 0; + for (int dd = d_start2; dd < d_end2; ++dd) { + for (int hh = h_start2; hh < h_end2; ++hh) { + for (int ww = w_start2; ww < w_end2; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_avg = tmp_avg + input[0]; + } + } + } + dst_c_ptr[0] = tmp_avg / (float)divisor; + } + offset_to_index_step(parallel_dims * 2, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..10bb12100d702f06cd610c160462b2026b1f35f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP32_POOLING_H_ +#define NNACL_FP32_POOLING_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); + +int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end); +void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP32_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..bdfb829f9fa6b4b63ebb9cddb25ad3f52170ae2f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in @@ -0,0 +1,116 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_POOLING_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_POOLING_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int AvgPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_plane_ptr, int channel, + float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end, + int in_h_index, int in_w, int in_w_index, float minf, float maxf) { + SIMD_F32 min_val = SIMD_MOV_F32(minf); + SIMD_F32 max_val = SIMD_MOV_F32(maxf); + for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + SIMD_F32 tmp_avg = SIMD_SET0_F32; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg = SIMD_ADD_F32(tmp_avg, SIMD_LD_F32(src_win_ptr)); + ++real_count; + } + } + tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(real_count)); + tmp_avg = SIMD_MAX_F32(tmp_avg, min_val); + tmp_avg = SIMD_MIN_F32(tmp_avg, max_val); + SIMD_ST_F32(dst_c_ptr, tmp_avg); + } + return ci; +} + +static inline int MaxPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_plane_ptr, int channel, + float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end, + int in_h_index, int in_w, int in_w_index, float minf, float maxf) { + SIMD_F32 min_val = SIMD_MOV_F32(minf); + SIMD_F32 max_val = SIMD_MOV_F32(maxf); + for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + SIMD_F32 tmp_max = min_val; + for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; + tmp_max = SIMD_MAX_F32(tmp_max, SIMD_LD_F32(src_win_ptr)); + } + } + tmp_max = SIMD_MIN_F32(tmp_max, max_val); + SIMD_ST_F32(dst_c_ptr, tmp_max); + } + return ci; +} + +static inline int MaxPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out, + int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride) { + for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + SIMD_F32 tmp_max = SIMD_MOV_F32(-FLT_MAX); + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_max = SIMD_MAX_F32(SIMD_LD_F32(input), tmp_max); + } + } + } + SIMD_ST_F32(dst_c_ptr, tmp_max); + } + return c_idx; +} + +static inline int AvgPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out, + int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride, int divisor) { + for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) { + const float *src_c_ptr = 
src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + SIMD_F32 tmp_avg = SIMD_SET0_F32; + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_avg = SIMD_ADD_F32(SIMD_LD_F32(input), tmp_avg); + } + } + } + tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(divisor)); + SIMD_ST_F32(dst_c_ptr, tmp_avg); + } + return c_idx; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4113e51c5abfda8df8eb61b94b4a89a94955d8a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c @@ -0,0 +1,70 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/power_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/power_fp32_simd.h" + +float OptimizedPowerScalar(float x, const float *exponent) { + int exp = abs((int)(*exponent)); + float result = 1; + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + return *exponent >= 0 ? result : 1 / result; +} + +void PowerBroadCast(const float *input, const float *exponent, float *output, int len, float scale, float shift) { + PowerScalarFun PowerScalarFun_ = NULL; + + int i = 0; + if (CheckInteger(*exponent)) { + PowerScalarFun_ = OptimizedPowerScalar; + SIMD_RUN_NO_SCALAR(PowerBroadCastIntExponent, i, input, (int)(*exponent), output, len, scale, shift); + } else { + PowerScalarFun_ = StdPowerScalar; + SIMD_RUN_NO_SCALAR(PowerBroadCastFloatExponent, i, input, *exponent, output, len, scale, shift); + } + + for (; i < len; ++i) { + output[i] = PowerScalarFun_(scale * input[i] + shift, exponent); + } +} + +void PowerSingle(const float *input, const float *exponent, float *output, int len, float scale, float shift) { + int i = 0; + + SIMD_RUN_NO_SCALAR(PowerSingleExponent, i, input, exponent, output, len, scale, shift); + PowerScalarFun PowerScalarFun_ = NULL; + for (; i < len; ++i) { + PowerScalarFun_ = CheckInteger(exponent[i]) ? OptimizedPowerScalar : StdPowerScalar; + output[i] = PowerScalarFun_(scale * input[i] + shift, exponent + i); + } +} + +int Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast) { + if (input == NULL || exponent == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + PowerFun PowerFun_ = NULL; + PowerFun_ = broadcast ? 
PowerBroadCast : PowerSingle; + PowerFun_(input, exponent, output, len, scale, shift); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..b2a2fb79b6dfc9b01fb6377616a0a35942bf6898 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_POWER_FP32_H_ +#define MINDSPORE_NNACL_FP32_POWER_FP32_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/pow_parameter.h" + +typedef void (*PowerFun)(const float *, const float *, float *, int, float, float); +typedef float (*PowerScalarFun)(float x, const float *exponent); + +#ifdef __cplusplus +extern "C" { +#endif +static inline bool CheckInteger(float f) { return fabsf(f - (int)(f)) < 0.000001; } + +static inline float StdPowerScalar(float x, const float *exponent) { return powf(x, *exponent); } + +int Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast); +void PowerSingle(const float *input, const float *exponent, float *output, int len, float scale, float shift); +void PowerBroadCast(const float *input, const float *exponent, float *output, int len, float scale, float shift); +#ifdef __cplusplus +} +#endif + +#endif  // MINDSPORE_NNACL_FP32_POWER_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..2b9f398ca07ed2d1c03462b6ea340cba608736d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in @@ -0,0 +1,94 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_POWER_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_POWER_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int PowerBroadCastIntExponent@SIMD_INSTRUCTION@(int index, const float *input, int exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + SIMD_F32 result = SIMD_MOV_F32(1.0f); + int exp = abs(exponent); + while (exp) { + if (exp % 2) { + result = SIMD_MUL_F32(result, tmp); + } + tmp = SIMD_MUL_SQUARE_F32(tmp); + exp = exp / 2; + } + SIMD_ST_F32(output + index, exponent >= 0 ? result : SIMD_DIV_F32(SIMD_MOV_F32(1), result)); + } + return index; +} + +static inline int PowerBroadCastFloatExponent@SIMD_INSTRUCTION@(int index, const float *input, float exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + SIMD_F32 result; + for (int i = 0; i < BLOCK_NUM; ++i) { + SIMD_F32_GETI(result, i) = powf(SIMD_F32_GETI(tmp, i), exponent); + } + SIMD_ST_F32(output + index, result); + } + return index; +} + +static inline int PowerSingleExponent@SIMD_INSTRUCTION@(int index, const float *input, const float *exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp_vec = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + for (int j = 0; j < BLOCK_NUM; ++j) { + float cur_exponent = exponent[index + j]; + float cur_val = SIMD_F32_GETI(tmp_vec, j); + if (fabsf(cur_exponent - (int)(cur_exponent)) < 0.000001) { + int exp = abs((int)(cur_exponent)); + float result = 1; + while (exp) { + if (exp % 2) { + result *= cur_val; + } + cur_val *= cur_val; + exp = exp / 2; + } + output[index + j] = cur_exponent >= 0 ? result : 1 / result; + } else { + output[index + j] = powf(cur_val, cur_exponent); + } + } + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d57e63b93ab4ef5b9643df01cefda022ba00b925 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c @@ -0,0 +1,164 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/prelu_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +#ifdef ENABLE_ARM64 +static inline void PRelu4x16(const float *in, float *out, const float *cur_slope, size_t step) { + asm volatile( + "mov x10, %[in]\n" + "mov x11, %[out]\n" + "mov x12, %[cur_slope]\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12]\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" + "fmul v16.4s, v0.4s, v4.4s\n" + "fmul v17.4s, v1.4s, v5.4s\n" + "fmul v18.4s, v2.4s, v6.4s\n" + "fmul v19.4s, v3.4s, v7.4s\n" + "fcmgt v20.4s, v0.4s, #0\n" + "fcmgt v21.4s, v1.4s, #0\n" + "fcmgt v22.4s, v2.4s, #0\n" + "fcmgt v23.4s, v3.4s, #0\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], %[step]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.4s, v24.4s, v4.4s\n" + "fmul v9.4s, v25.4s, v5.4s\n" + "fmul v10.4s, v26.4s, v6.4s\n" + "fmul v11.4s, v27.4s, v7.4s\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" + "fcmgt v12.4s, v24.4s, #0\n" + "fcmgt v13.4s, v25.4s, #0\n" + "fcmgt v14.4s, v26.4s, #0\n" + "fcmgt v15.4s, v27.4s, #0\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "fmul v16.4s, v0.4s, v4.4s\n" + "fmul v17.4s, v1.4s, v5.4s\n" + "fmul v18.4s, v2.4s, v6.4s\n" + "fmul v19.4s, v3.4s, v7.4s\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], %[step]\n" + "fcmgt v20.4s, v0.4s, #0\n" + "fcmgt v21.4s, v1.4s, #0\n" + "fcmgt v22.4s, v2.4s, #0\n" + "fcmgt v23.4s, v3.4s, #0\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.4s, v24.4s, v4.4s\n" + "fmul v9.4s, v25.4s, v5.4s\n" + "fmul v10.4s, v26.4s, v6.4s\n" + "fmul v11.4s, v27.4s, v7.4s\n" + "fcmgt v12.4s, v24.4s, #0\n" + "fcmgt v13.4s, v25.4s, #0\n" + "fcmgt v14.4s, v26.4s, #0\n" + "fcmgt v15.4s, v27.4s, #0\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11]\n" + : + : [in] "r"(in), [out] "r"(out), [cur_slope] "r"(cur_slope), [step] "r"(step) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27"); +} +#endif + +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel) { + int i = start; +#ifdef ENABLE_ARM64 + for (; i < end - 3; i += 4) { + const float *cur_in = input + i * channel; + float *cur_out = output + i * channel; + int j = 0; + for (; j < channel - 15; j += 16) { + const float *in = cur_in + j; + float *out = cur_out + j; + const float *cur_slope = slope + j; + size_t step = channel * sizeof(float); + 
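+      // PRelu4x16 (assembly above) handles a 4x16 tile: four consecutive spatial rows of 16 channels each, with step being the row stride in bytes; per element it computes out = in > 0 ? in : in * slope, matching the scalar tail below.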
PRelu4x16(in, out, cur_slope, step); + } + for (; j < channel; j++) { + cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); + cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; + cur_out[j + 2 * channel] = + (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); + cur_out[j + 3 * channel] = + (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); + } + } +#endif + for (; i < end; i++) { + const float *cur_in = input + i * channel; + float *cur_out = output + i * channel; + int j = 0; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; j < channel - 3; j += 4) { + MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j); + MS_FLOAT32X4 s = MS_LDQ_F32(slope + j); + MS_FLOAT32X4 mul = MS_MULQ_F32(in, s); + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 res = MS_BLENDQ_F32(in, mul, MS_CMPLEQ_F32(in, zero)); + MS_STQ_F32(cur_out + j, res); + } +#endif + for (; j < channel; j++) { + if (cur_in[j] > 0) { + cur_out[j] = cur_in[j]; + } else { + cur_out[j] = cur_in[j] * slope[j]; + } + } + } +} + +void PReluShareChannel(const float *input, float *output, float slope, int start, int end) { + int i = start; +#if defined(ENABLE_AVX) +#define mask_offset 30 + for (; i <= end - C8NUM; i += C8NUM) { + MS_FLOAT32X8 src_tmp = MS_LD256_F32(input + i); + MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, slope); + MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), mask_offset); + MS_ST256_F32(output + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask)); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; i <= end - C4NUM; i += C4NUM) { + MS_FLOAT32X4 src_tmp = MS_LDQ_F32(input + i); + MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, slope); +#ifdef ENABLE_ARM + MS_UINT32X4 mask = MS_CMPLEQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); +#else + MS_FLOAT32X4 mask = MS_CMPLEQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); +#endif + MS_STQ_F32(output + i, MS_BLENDQ_F32(src_tmp, mul_tmp, mask)); + } +#endif + for (; i < end; i++) { + output[i] = input[i] > 0 ? input[i] : input[i] * slope; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3c3c7f2ff8327a0037479ec96b27e2000e111402 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_PRELU_H_ +#define MINDSPORE_NNACL_FP32_PRELU_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel); + +void PReluShareChannel(const float *input, float *output, float slope, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif  // MINDSPORE_NNACL_FP32_PRELU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..2ef55f69be205e8bc5f8cecb896aa29096de1faa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ +#define MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ + +#include <memory.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/prior_box_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +static int PriorBox(const float *input_data, float *output_data, const size_t size, const int tid, + const int thread_num) { + NNACL_CHECK_NULL_RETURN_ERR(input_data); + NNACL_CHECK_NULL_RETURN_ERR(output_data); + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + size_t unit_size = size / thread_num; + size_t copy_size = (tid == thread_num - 1) ? size - unit_size * tid : unit_size; + (void)memcpy(output_data + tid * unit_size, input_data + tid * unit_size, copy_size * sizeof(float)); + return NNACL_OK; +} +#ifdef __cplusplus +} +#endif + +#endif  // MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b73d4c8100495247405ec6b74453f5855e5b2b17 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/ragged_range_fp32.h" +#include <math.h> +#include "nnacl_c/op_base.h" + +void RaggedRangeFp32(const float *starts, const float *limits, const float *deltas, int32_t *splits, float *value, + RaggedRangeStruct *ragged_range) { + splits[0] = 0; + for (int i = 0; i < ragged_range->rows_; i++) { + float start = ragged_range->starts_is_scalar_ ? starts[0] : starts[i]; + float limit = ragged_range->limits_is_scalar_ ? limits[0] : limits[i]; + float delta = ragged_range->deltas_is_scalar_ ? deltas[0] : deltas[i]; + int len = NNACL_MAX((int)ceil((float)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} + +void RaggedRangeInt(const int32_t *starts, const int32_t *limits, const int32_t *deltas, int32_t *splits, + int32_t *value, RaggedRangeStruct *ragged_range) { + splits[0] = 0; + for (int i = 0; i < ragged_range->rows_; i++) { + int start = ragged_range->starts_is_scalar_ ? starts[0] : starts[i]; + int limit = ragged_range->limits_is_scalar_ ? limits[0] : limits[i]; + int delta = ragged_range->deltas_is_scalar_ ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN(delta); + int len = NNACL_MAX((int)ceil((float)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..e44846dd7745158281ca2002d64dec4732f6ab69 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_RAGGED_RANGE_FP32_H_ +#define NNACL_FP32_RAGGED_RANGE_FP32_H_ + +#include "nnacl_c/kernel/ragged_range.h" + +void RaggedRangeFp32(const float *starts, const float *limits, const float *deltas, int32_t *splits, float *value, + RaggedRangeStruct *ragged_range); +void RaggedRangeInt(const int32_t *starts, const int32_t *limits, const int32_t *deltas, int32_t *splits, + int32_t *value, RaggedRangeStruct *ragged_range); + +#endif  // NNACL_FP32_RAGGED_RANGE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..2058ff1f3d2a9fcd48d17e798a31667cb6b75b2e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_RANGE_FP32_H_ +#define NNACL_FP32_RANGE_FP32_H_ + +#include "nnacl_c/op_base.h" + +void Range(float *output_ptr, float start, float delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +void RangeInt(int32_t *output_ptr, int start, int delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +#endif // NNACL_FP32_RANGE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..fa7857e1abccf080595071b6432111181e5b5f2b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANK_H_ +#define MINDSPORE_NNACL_RANK_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +inline void Rank(float *output, int rank) { + output[0] = (float)(rank); + return; +} +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_RANK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d8931b3499d3ec747edb690e926fb4ab67d0c2aa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c @@ -0,0 +1,359 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/reduce_fp32.h" +#include <float.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/reduce_fp32_simd.h" +#ifdef ENABLE_NNACL_INFER_SHAPE +#include "nnacl_c/reduce_parameter.h" +#endif + +// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1) +#define ReduceCoreCalc(op_name, op_type, outer_src, outer_dst, k) \ + for (; k < inner_size; k++) { \ + const op_type *inner_src = outer_src + k; \ + op_name##PreDeal; \ + for (int i = 0; i < axis_size; i++) { \ + op_name##MidCalc; \ + } \ + op_name##PostDeal; \ + } + +#define RegReduceOp(op_name, op_type) \ + int op_name(int outer_size, int inner_size, int axis_size, const op_type *src_data, op_type *dst_data, int tid, \ + int thread_num) { \ + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); \ + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); \ + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); \ + for (int j = tid; j < outer_size; j += thread_num) { \ + const op_type *outer_src = src_data + j * axis_size * inner_size; \ + op_type *outer_dst = dst_data + j * inner_size; \ + int k = 0; \ + SIMD_RUN_NO_SCALAR(op_name, k, outer_src, outer_dst, inner_size, axis_size); \ + \ + ReduceCoreCalc(op_name, op_type, outer_src, outer_dst, k); \ + } \ + return NNACL_OK; \ + } + +// ReduceSum +#define ReduceSumPreDeal float tmp = 0; +#define ReduceSumMidCalc tmp += inner_src[i * inner_size]; +#define ReduceSumPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceSum, float); + +int ReduceSumByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num) { + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); + + for (int j = tid; j < outer_size; j += thread_num) { + const float *src_tmp = src_data + j * axis_size; + + float tmp = src_tmp[0]; + int i = 1; + + SIMD_RUN_NO_SCALAR(ReduceSumByLastAxis, i, src_tmp, &tmp, axis_size); + for (; i < axis_size; i++) { + tmp += src_tmp[i]; + } + dst_data[j] = tmp; + } + return NNACL_OK; +} + +// ReduceMean +#define ReduceMeanPreDeal float tmp = 0; +#define ReduceMeanMidCalc tmp += inner_src[i * inner_size]; +#define ReduceMeanPostDeal outer_dst[k] = tmp / axis_size; +RegReduceOp(ReduceMean, float); + +// ReduceMin +#define ReduceMinPreDeal float tmp = FLT_MAX; +#define ReduceMinMidCalc tmp = fminf(tmp, inner_src[i * inner_size]); +#define ReduceMinPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceMin, float); + +// ReduceMax +#define ReduceMaxPreDeal float tmp = FLT_MIN; +#define ReduceMaxMidCalc tmp = fmaxf(tmp, inner_src[i * inner_size]); +#define ReduceMaxPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceMax, float); + +int ReduceMaxByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num) { + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); + + for (int j = tid; j < outer_size; j += thread_num) { + const float *src_tmp = src_data + j * axis_size; + + float tmp = src_tmp[0]; + int i = 1; + + SIMD_RUN_NO_SCALAR(ReduceMaxByLastAxis, i, src_tmp, &tmp, axis_size); + for (; i < axis_size; i++) { + tmp = fmaxf(tmp, src_tmp[i]); + } + dst_data[j] = tmp; + } + return NNACL_OK; +} + +// ReduceProd +#define ReduceProdPreDeal float tmp = 1.0f; 
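+// PreDeal seeds the multiplicative identity, MidCalc (next line) folds one axis element into the accumulator per step, and PostDeal writes the product back; RegReduceOp stitches the trio into roughly "float tmp = 1.0f; for (i < axis_size) tmp *= inner_src[i * inner_size]; outer_dst[k] = tmp;" per inner position -- a sketch of the generated loop, not the literal expansion.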
+#define ReduceProdMidCalc tmp *= inner_src[i * inner_size]; +#define ReduceProdPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceProd, float); + +// ReduceSumSquare +#define ReduceSumSquarePreDeal float tmp = 0; +#define ReduceSumSquareMidCalc tmp += (inner_src[i * inner_size] * inner_src[i * inner_size]); +#define ReduceSumSquarePostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceSumSquare, float); + +// ReduceL2Norm +#define ReduceL2NormPreDeal float tmp = 0; +#define ReduceL2NormMidCalc tmp += (inner_src[i * inner_size] * inner_src[i * inner_size]); +#define ReduceL2NormPostDeal outer_dst[k] = sqrt(tmp); +RegReduceOp(ReduceL2Norm, float); + +// IntReduceSum +#define IntReduceSumPreDeal int tmp = 0; +#define IntReduceSumMidCalc tmp += inner_src[i * inner_size]; +#define IntReduceSumPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceSum, int32_t); + +// IntReduceMean +#define IntReduceMeanPreDeal int tmp = 0; +#define IntReduceMeanMidCalc tmp += inner_src[i * inner_size]; +#define IntReduceMeanPostDeal outer_dst[k] = tmp / axis_size; +RegReduceOp(IntReduceMean, int32_t); + +// IntReduceMin +#define IntReduceMinPreDeal int tmp = INT32_MAX; +#define IntReduceMinMidCalc tmp = MSMIN(tmp, inner_src[i * inner_size]); +#define IntReduceMinPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceMin, int32_t); + +// IntReduceMax +#define IntReduceMaxPreDeal int tmp = INT32_MIN; +#define IntReduceMaxMidCalc tmp = MSMAX(tmp, inner_src[i * inner_size]); +#define IntReduceMaxPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceMax, int32_t); + +int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid, + int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const bool *outer_src = src_data + j * axis_size * inner_size; + bool *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const bool *inner_src = outer_src + k; + bool *inner_dst = outer_dst + k; + bool tmp = true; + for (i = 0; i < axis_size; i++) { + tmp = tmp && inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int IntReduceProd(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int tmp = 1; + for (i = 0; i < axis_size; i++) { + if (isMulOverflow(tmp, inner_src[i * inner_size])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NNACL_INFER_SHAPE +int ReduceInferShape(int32_t **in_shape, size_t *dim_size, int32_t *out_shape, int32_t *in_format, int32_t *out_format, + int32_t *in_datatype, int32_t *out_datatype, OpParameter *param) { + *out_format = in_format[0]; + *out_datatype = in_datatype[0]; + ReduceParameter *reduce_parameter = (ReduceParameter *)param; + bool keep_dims = reduce_parameter->keep_dims_; + int num_axes = reduce_parameter->num_axes_; + int32_t *in_shape0 = in_shape[0]; + 
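+  // Axis normalization below: negative entries wrap into [0, rank) (e.g. axis -1 with rank 4 becomes 3), then reduce_to_end_ expands a single begin axis to every trailing axis before the output shape is assembled.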
int rank = dim_size[0]; + NNACL_CHECK_TRUE_RET(rank > 0 && rank <= REDUCE_MAX_AXES_NUM, NNACL_PARAM_INVALID); + int axes[REDUCE_MAX_AXES_NUM]; + int actual_axes_num = num_axes; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_TRUE_RET(reduce_parameter->axes_[i] >= -rank && reduce_parameter->axes_[i] < rank, NNACL_PARAM_INVALID); + if (reduce_parameter->axes_[i] < 0) { + axes[i] = reduce_parameter->axes_[i] + rank; + } else { + axes[i] = reduce_parameter->axes_[i]; + } + } + if (reduce_parameter->reduce_to_end_) { + NNACL_CHECK_TRUE_RET(num_axes == 1, NNACL_PARAM_INVALID); + int begin_axis = axes[0]; + num_axes = rank - begin_axis; + for (int i = begin_axis + 1; i < rank; ++i) { + axes[actual_axes_num++] = i; + } + } + if (num_axes == 0) { + int j = 0; + for (int i = 0; i < rank; ++i) { + axes[i] = i; + if (keep_dims) { + out_shape[j++] = 1; + } + } + reduce_parameter->num_axes_ = rank; + for (int i = 0; i < rank; ++i) { + reduce_parameter->axes_[i] = axes[i]; + } + return NNACL_OK; + } + // reduce on selected axes + int j = 0; + for (int i = 0; i < rank; ++i) { + bool reduce_axis = false; + for (int idx = 0; idx < num_axes; ++idx) { + if (axes[idx] == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + out_shape[j++] = 1; + } + } else { + out_shape[j++] = in_shape0[i]; + } + } + reduce_parameter->num_axes_ = num_axes; + for (int i = 0; i < num_axes; ++i) { + reduce_parameter->axes_[i] = axes[i]; + } + return NNACL_OK; +} +#endif + +// [A, B] -> [B] +// col_size : start -> end for parallel +int ReduceSumDim2Axis0(size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + + size_t k = 0; + SIMD_RUN_NO_SCALAR(ReduceSumDim2Axis0, k, col_size, col_len, row_len, src_data, dst_data); + for (; k < col_size; k++) { + const float *inner_src = src_data + k; + float *inner_dst = dst_data + k; + float tmp = 0.0f; + for (size_t i = 0; i < row_len; i++) { + tmp += inner_src[i * col_len]; + } + *inner_dst = tmp; + } + return NNACL_OK; +} + +// [A, B] -> [A] +int ReduceSumDim2Axis1(size_t col_len, const float *src_data, float *dst_data) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + size_t k = 0; + float tmp = 0; +#ifdef ENABLE_AVX + size_t block_mod = col_len % C8NUM; + size_t block_c8 = col_len - block_mod; + float tmp_arr[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 tmp_arr_8 = MS_MOV256_F32(tmp_arr[0]); + for (; k < block_c8; k += C8NUM) { + MS_FLOAT32X8 src_in = MS_LD256_F32(src_data + k); + tmp_arr_8 = MS_ADD256_F32(tmp_arr_8, src_in); + } + MS_ST256_F32(tmp_arr, tmp_arr_8); + for (size_t i = 0; i < 8; ++i) { + tmp += tmp_arr[i]; + } +#endif + for (; k < col_len; k++) { + tmp += src_data[k]; + } + dst_data[0] = tmp; + return NNACL_OK; +} + +int ReduceMeanWithAxis(const float *src_data, float *mean, int64_t size) { + if (size == 0 || src_data == NULL) { + return NNACL_NULL_PTR; + } + float sum = 0.0; + int64_t i = 0; + SIMD_RUN_NO_SCALAR(ReduceSumByLastAxis, i, src_data, &sum, size); + for (; i < size; ++i) { + sum += src_data[i]; + } + *mean = sum / size; + return NNACL_OK; +} + +int ReduceDeviation(const float *src_data, int64_t size, float mean, float *deviation) { + if (size == 0 || src_data == NULL) { + return NNACL_NULL_PTR; + } + int64_t i = 0; + SIMD_RUN_NO_SCALAR(FloatReduceDeviation, i, src_data, mean, size, deviation); + for (; i < size; ++i) { + float tmp = src_data[i] - mean; + tmp = tmp * tmp; + *deviation += tmp; 
+ } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c33c4df4232744adf53324e1f5e66412afaed734 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_REDUCE_H_ +#define MINDSPORE_NNACL_FP32_REDUCE_H_ +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMean(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceSumByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceSum(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceMaxByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMax(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMin(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceProd(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceL2Norm(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid, + int thread_num); +int ReduceSumDim2Axis0(size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data); +int ReduceSumDim2Axis1(size_t col_len, const float *src_data, float *dst_data); +int ReduceMeanWithAxis(const float *src_data, float *mean, int64_t size); +int ReduceDeviation(const float *src_data, int64_t size, 
float mean, float *deviation); + +#ifdef ENABLE_NNACL_INFER_SHAPE +int ReduceInferShape(int32_t **in_shape, size_t *dim_size, int32_t *out_shape, int32_t *in_format, int32_t *out_format, + int32_t *in_datatype, int32_t *out_datatype, OpParameter *param); +#endif +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_REDUCE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..eee95b4294fbe16623625f294f1517bbcef62015 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in @@ -0,0 +1,220 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_FP32_REDUCE_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_REDUCE_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumByLastAxis@SIMD_INSTRUCTION@(int64_t index, const float *src, float* tmp_sum, int axis_size) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = axis_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(src + index)); + } + *tmp_sum += SIMD_GET_SUM_F32(tmp); + return index; +} + +static inline int64_t ReduceMean@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, SIMD_DIV_N_F32(tmp, axis_size)); + } + return index; +} + +static inline int64_t ReduceMin@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(FLT_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MIN_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + 
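+    // each lane has tracked the running minimum of one of BLOCK_NUM adjacent inner positions; write the reduced block back to the output row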
SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceMax@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(FLT_MIN); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MAX_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceMaxByLastAxis@SIMD_INSTRUCTION@(int64_t index, const float *src, float* tmp_max, int axis_size) { + SIMD_F32 tmp = SIMD_MOV_F32(*tmp_max); + for (int block_max_size = axis_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_MAX_F32(tmp, SIMD_LD_F32(src + index)); + } + *tmp_max = SIMD_GET_MAX_F32(tmp); + return index; +} + +static inline int64_t ReduceProd@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(1.0f); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MUL_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumSquare@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_SQUARE_F32(SIMD_LD_F32(inner_src + i * inner_size))); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceL2Norm@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_SQUARE_F32(SIMD_LD_F32(inner_src + i * inner_size))); + } + SIMD_ST_F32(outer_dst + index, SIMD_SQRT_F32(tmp)); + } + return index; +} + +static inline int64_t IntReduceSum@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t IntReduceMean@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, 
SIMD_DIV_N_EPI32(tmp, axis_size)); + } + return index; +} + +static inline int64_t IntReduceMin@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(INT32_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MIN_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t IntReduceMax@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(INT32_MIN); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MAX_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumDim2Axis0@SIMD_INSTRUCTION@(int64_t index, size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data) { + for (int block_max_size = col_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + const float *inner_src = src_data + index; + float *inner_dst = dst_data + index; + for (size_t i = 0; i < row_len; ++i) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * col_len)); + } + SIMD_ST_F32(inner_dst, tmp); + } + return index; +} + +static inline int64_t FloatReduceDeviation@SIMD_INSTRUCTION@(int64_t index, const float *src_data, float mean, size_t size, float *deviation) { + SIMD_F32 fs_deviation = SIMD_MOV_F32(0); + SIMD_F32 fs_mean = SIMD_MOV_F32(mean); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 fs_sub = SIMD_LD_F32(src_data + index); + + fs_sub = SIMD_SUB_F32(fs_sub, fs_mean); + SIMD_F32 fs_pow = SIMD_MUL_F32(fs_sub, fs_sub); + fs_deviation = SIMD_ADD_F32(fs_deviation, fs_pow); + } + *deviation += SIMD_GET_SUM_F32(fs_deviation); + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..86476f0adad65117d2c6281216b11365d8a3f730 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c @@ -0,0 +1,598 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <math.h> +#include "nnacl_c/fp32/resize_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void CalculateCoordinate(float out, int in, int32_t *bottom, int32_t *top, float *bottom_weight) { + *bottom = (int)(floorf(out)); + *bottom = *bottom >= 0 ? *bottom : 0; // extrapolate may generate neg value + *top = *bottom + 1 < in ? (*bottom + 1) : (in - 1); + float top_weight = (float)out - (float)(*bottom); + *bottom_weight = 1.0f - top_weight; +} + +static void BicubicBaseFunc(float a, const float x, float *weight) { + float abs_x = fabsf(x); + if (abs_x >= 0 && abs_x <= 1) { + *weight = ((a + 2) * abs_x - (a + 3)) * abs_x * abs_x + 1; + } else if (abs_x > 1 && abs_x <= 2) { + *weight = a * abs_x * abs_x * abs_x - 5 * a * abs_x * abs_x + 8 * a * abs_x - 4 * a; + } else { + *weight = 0; + } +} + +// a is a coefficient +// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1 +// { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a * |x| - 4 * a, for 1 < |x| < 2 +// { 0, otherwise +// The value of 'a' depends on whether half_pixel_centers is used (the scheme is the same as TF): +// a equals -0.5 in half-pixel mode, otherwise -0.75. +void CalculateWeightForBicubic(float out, int in, int32_t *index, float *weights, float a) { + int floor_index = (int)(floorf(out)); + index[0] = (floor_index - 1) < 0 ? 0 : (floor_index - 1); + index[1] = floor_index; + index[2] = (floor_index + 1) < in ? (floor_index + 1) : (in - 1); + index[3] = (floor_index + 2) < in ? (floor_index + 2) : (in - 1); + + // distances from each of the four taps to the sample position + float distance[4] = {-1, 0, 1, 2}; + float tmp_dis = out - (float)floor_index; + distance[0] -= tmp_dis; + distance[1] -= tmp_dis; + distance[2] -= tmp_dis; + distance[3] -= tmp_dis; + + for (int i = 0; i < 4; ++i) { + BicubicBaseFunc(a, distance[i], &weights[i]); + } +} + +int PrepareResizeBilinear(const int32_t *input_shape, const int32_t *output_shape, + CalculateOriginalCoordinate calculate, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float actual_y = calculate(h, in_h, new_height); + CalculateCoordinate(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h); + } + for (int w = 0; w < new_width; w++) { + float actual_x = calculate(w, in_w, new_width); + CalculateCoordinate(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w); + } + return NNACL_OK; +} + +int PrepareResizeBicubic(const int32_t *input_shape, const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int32_t *y_tops, int32_t *x_lefts, float *y_weights, float *x_weights, float cubic_coeff) { + if (input_shape == NULL || output_shape == NULL || y_tops == NULL || x_lefts == NULL || y_weights == NULL || + x_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float actual_y = calculate(h, in_h, new_height); + CalculateWeightForBicubic(actual_y, in_h, y_tops + 4 * h,
y_weights + 4 * h, cubic_coeff); + } + for (int w = 0; w < new_width; w++) { + float actual_x = calculate(w, in_w, new_width); + CalculateWeightForBicubic(actual_x, in_w, x_lefts + 4 * w, x_weights + 4 * w, cubic_coeff); + } + return NNACL_OK; +} + +int PrepareCropAndResizeBilinear(const int32_t *input_shape, const float *boxes, const int32_t *box_idx, + const int32_t *output_shape, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + float actual_x; + float actual_y; + + for (int b = 0; b < batch; b++) { + const float *box = boxes + b * 4; + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + + int32_t *y_bottom = y_bottoms + b * new_height; + int32_t *y_top = y_tops + b * new_height; + float *y_bottom_weight = y_bottom_weights + b * new_height; + int32_t *x_left = x_lefts + b * new_width; + int32_t *x_right = x_rights + b * new_width; + float *x_left_weight = x_left_weights + b * new_width; + for (int h = 0; h < new_height; h++) { + if (new_height > 1) { + actual_y = start_h * (in_h - 1) + h * (end_h - start_h) * (in_h - 1) / (new_height - 1); + } else { + actual_y = 0.5 * (end_h + start_h) * (in_h - 1); + } + CalculateCoordinate(actual_y, in_h, y_bottom + h, y_top + h, y_bottom_weight + h); + } + for (int w = 0; w < new_width; w++) { + if (new_width > 1) { + actual_x = start_w * (in_w - 1) + w * (end_w - start_w) * (in_w - 1) / (new_width - 1); + } else { + actual_x = 0.5 * (end_w + start_w) * (in_w - 1); + } + CalculateCoordinate(actual_x, in_w, x_left + w, x_right + w, x_left_weight + w); + } + } + return NNACL_OK; +} + +int InterpRow(const float *src_line, float *linear_output, int new_width, const float *x_left_weights, + const int32_t *x_lefts, const int32_t *x_rights, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 left_w_8 = MS_MOV256_F32(x_left_weights[w]); + MS_FLOAT32X8 right_w_8 = MS_MOV256_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 left = MS_LD256_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X8 right = MS_LD256_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(MS_MUL256_F32(left, left_w_8), MS_MUL256_F32(right, right_w_8)); + MS_ST256_F32(linear_output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 left_w = MS_MOVQ_F32(x_left_weights[w]); + MS_FLOAT32X4 right_w = MS_MOVQ_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 left = MS_LDQ_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X4 right = MS_LDQ_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(MS_MULQ_F32(left, left_w), MS_MULQ_F32(right, right_w)); + MS_STQ_F32(linear_output + w * in_c + c, interp_value); + } +#endif + int left_w_offset = x_lefts[w] * in_c; + int right_w_offset = x_rights[w] * in_c; + for (; c < in_c; c++) { + float left = src_line[left_w_offset + c]; + float right = src_line[right_w_offset + c]; + linear_output[w * in_c + c] = left * 
x_left_weights[w] + right * (1.0f - x_left_weights[w]); + } + } + return 0; +} + +int InterpCol(const float *bottom_line, const float *top_line, float *output, int new_width, float y_bottom_weight, + int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 bottom_w_8 = MS_MOV256_F32(y_bottom_weight); + MS_FLOAT32X8 top_w_8 = MS_MOV256_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 bottom = MS_LD256_F32(bottom_line + w * in_c + c); + MS_FLOAT32X8 top = MS_LD256_F32(top_line + w * in_c + c); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(MS_MUL256_F32(bottom, bottom_w_8), MS_MUL256_F32(top, top_w_8)); + MS_ST256_F32(output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 bottom_w = MS_MOVQ_F32(y_bottom_weight); + MS_FLOAT32X4 top_w = MS_MOVQ_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 bottom = MS_LDQ_F32(bottom_line + w * in_c + c); + MS_FLOAT32X4 top = MS_LDQ_F32(top_line + w * in_c + c); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(MS_MULQ_F32(bottom, bottom_w), MS_MULQ_F32(top, top_w)); + MS_STQ_F32(output + w * in_c + c, interp_value); + } +#endif + for (; c < in_c; c++) { + float bottom = bottom_line[w * in_c + c]; + float top = top_line[w * in_c + c]; + output[w * in_c + c] = bottom * y_bottom_weight + top * (1.0f - y_bottom_weight); + } + } + return 0; +} + +void Bilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottom, const int32_t *y_top, const int32_t *x_left, const int32_t *x_right, + const float *y_bottom_weight, const float *x_left_weight, float *line0, float *line1, const int h_begin, + const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + bool cache_line_used[2] = {false, false}; + int cache_line_num[2] = {-1, -1}; + float *const cache_line_ptr[2] = {line0, line1}; + float *current_line_ptr[2] = {line0, line1}; + int current_line_num[2] = {-1, -1}; + + for (int h = h_begin; h < h_end; h++) { + current_line_num[0] = y_bottom[h]; + current_line_num[1] = y_top[h]; + + for (int i = 0; i < 2; i++) { + cache_line_used[i] = false; + } + // search if we cached + for (int j = 0; j < 2; j++) { + bool find = false; + for (int k = 0; k < 2; k++) { + if (current_line_num[j] == cache_line_num[k]) { + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + find = true; + break; + } + } + + if (!find) { + const float *line = input_data + current_line_num[j] * in_w * in_c; + for (int k = 0; k < 2; k++) { + if (!cache_line_used[k]) { + cache_line_num[k] = current_line_num[j]; + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + InterpRow(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); + break; + } + } + } + } + // do col interp + InterpCol(current_line_ptr[0], current_line_ptr[1], output_data + h * h_stride, new_width, y_bottom_weight[h], + in_c); + } +} + +int ResizeBilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, const int32_t *x_rights, + const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, + const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL 
|| output_shape == NULL || y_bottoms == NULL || + y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_b = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int b = 0; b < in_b; b++) { + const float *input = input_data + b * in_h * in_w * in_c; + float *output = output_data + b * new_height * new_width * in_c; + Bilinear(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_bottom_weights, + x_left_weights, line0, line1, h_begin, h_end); + } + return NNACL_OK; +} + +void BicubicInterpRow(const float *src, float *dst, const float *weights, const int32_t *lefts, int width, + int channel) { + for (int w = 0; w < width; w++) { + const float *weight = weights + 4 * w; + float *dst_w = dst + w * channel; + const float *src0_w = src + lefts[4 * w] * channel; + const float *src1_w = src + lefts[4 * w + 1] * channel; + const float *src2_w = src + lefts[4 * w + 2] * channel; + const float *src3_w = src + lefts[4 * w + 3] * channel; + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weight[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weight[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weight[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst0 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst3, MS_ADD256_F32(dst2, MS_ADD256_F32(dst1, dst0))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weight[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weight[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weight[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weight[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst0 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst2, MS_ADDQ_F32(dst1, dst0))); + MS_STQ_F32(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weight[0] + src1_w[c] * weight[1] + src2_w[c] * weight[2] + src3_w[c] * weight[3]; + } + } +} + +void BicubicInterpCol(const float *src, float *dst, const float *weights, int width, int channel) { + const float *src0 = src; + const float *src1 = src + width * channel; + const float *src2 = src + 2 * width * channel; + const float *src3 = src + 3 * width * channel; + for (int w = 0; w < width; w++) { + float *dst_w = dst + w * channel; + const float *src0_w = src0 + w * channel; + const float *src1_w 
= src1 + w * channel; + const float *src2_w = src2 + w * channel; + const float *src3_w = src3 + w * channel; + int c = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weights[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weights[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weights[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst4 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst4, MS_ADD256_F32(dst3, MS_ADD256_F32(dst1, dst2))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weights[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weights[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weights[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weights[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst4 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst4, MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst1, dst2))); + MS_STQ_F32(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weights[0] + src1_w[c] * weights[1] + src2_w[c] * weights[2] + src3_w[c] * weights[3]; + } + } +} + +void Bicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + for (int h = h_begin; h < h_end; h++) { + for (int i = 0; i < 4; ++i) { + BicubicInterpRow(input_data + y_tops[4 * h + i] * in_w * in_c, line_buffer + i * h_stride, x_weights, x_lefts, + new_width, in_c); + } + BicubicInterpCol(line_buffer, output_data + h * h_stride, y_weights + 4 * h, new_width, in_c); + } +} + +int ResizeBicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_tops == NULL || + x_lefts == NULL || y_weights == NULL || x_weights == NULL) { + return NNACL_NULL_PTR; + } + int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3]; + int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3]; + for (int b = 0; b < input_shape[0]; b++) { + const float *input = input_data + b * input_cube_per_batch; + float 
*output = output_data + b * output_cube_per_batch; + Bicubic(input, output, input_shape, output_shape, y_tops, x_lefts, y_weights, x_weights, line_buffer, h_begin, + h_end); + } + return NNACL_OK; +} + +int RewriteExtrapolationValue(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || box_idx == NULL || input_shape == NULL || output_shape == NULL) { + return NNACL_NULL_PTR; + } + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + int new_channel = output_shape[3]; + int input_h = input_shape[1]; + int input_w = input_shape[2]; + + for (int b = 0; b < batch; b++) { + float *output = output_data + b * new_height * new_width * new_channel; + const float *box = boxes + 4 * b; + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + float actual_y, actual_x; + for (int h = h_begin; h < h_end; ++h) { + if (new_height > 1) { + actual_y = start_h * (input_h - 1) + h * (end_h - start_h) * (input_h - 1) / (new_height - 1); + } else { + actual_y = 0.5 * (end_h + start_h) * (input_h - 1); + } + if (actual_y < 0 || actual_y > input_h - 1) { + float *output_data_base = output + h * new_width * new_channel; + for (int x = 0; x < new_width; ++x) { + for (int d = 0; d < new_channel; ++d) { + *output_data_base = extrapolation_value; + output_data_base++; + } + } + } + for (int w = 0; w < new_width; ++w) { + if (new_width > 1) { + actual_x = start_w * (input_w - 1) + w * (end_w - start_w) * (input_w - 1) / (new_width - 1); + } else { + actual_x = 0.5 * (end_w + start_w) * (input_w - 1); + } + if (actual_x < 0 || actual_x > input_w - 1) { + float *output_data_base = output + h * new_width * new_channel + w * new_channel; + for (int d = 0; d < new_channel; ++d) { + output_data_base[d] = extrapolation_value; + } + } + } + } + } + return NNACL_OK; +} + +int CropAndResizeBilinear(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, + const int32_t *x_rights, const float *y_bottom_weights, const float *x_left_weights, + float *line0, float *line1, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || box_idx == NULL || input_shape == NULL || output_shape == NULL || + y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || + x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + int new_channel = output_shape[3]; + int input_h = input_shape[1]; + int input_w = input_shape[2]; + + for (int b = 0; b < batch; b++) { + const float *cur_img = input_data + box_idx[b] * input_h * input_w * new_channel; + const int32_t *y_bottom = y_bottoms + b * new_height; + const int32_t *y_top = y_tops + b * new_height; + const float *y_bottom_weight = y_bottom_weights + b * new_height; + const int32_t *x_left = x_lefts + b * new_width; + const int32_t *x_right = x_rights + b * new_width; + const float *x_left_weight = x_left_weights + b * new_width; + float *output = output_data + b * new_height * new_width * new_channel; + + 
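The four bicubic taps produced by CalculateWeightForBicubic always sum to 1 for any fractional offset, which is why no renormalization appears in BicubicInterpRow/BicubicInterpCol above. A small self-contained check, using a local copy of the piecewise cubic (illustration only):

#include <math.h>
#include <stdio.h>

/* Local copy of the piecewise cubic W(x) from resize_fp32.c, for the check only. */
static float CubicWeight(float a, float x) {
  float abs_x = fabsf(x);
  if (abs_x <= 1.0f) {
    return ((a + 2.0f) * abs_x - (a + 3.0f)) * abs_x * abs_x + 1.0f;
  } else if (abs_x <= 2.0f) {
    return a * abs_x * abs_x * abs_x - 5.0f * a * abs_x * abs_x + 8.0f * a * abs_x - 4.0f * a;
  }
  return 0.0f;
}

int main(void) {
  float a = -0.75f;  /* the non-half-pixel coefficient mentioned in the comment above */
  float frac = 0.3f; /* out = floor(out) + 0.3 */
  float sum = 0.0f;
  for (int i = 0; i < 4; ++i) {
    float w = CubicWeight(a, (float)(i - 1) - frac); /* distances -1.3, -0.3, 0.7, 1.7 */
    printf("w[%d] = %f\n", i, w);
    sum += w;
  }
  printf("sum = %f\n", sum); /* exactly 1 for any frac: the taps form a partition of unity */
  return 0;
}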
Bilinear(cur_img, output, input_shape, output_shape, y_bottom, y_top, x_left, x_right, y_bottom_weight, + x_left_weight, line0, line1, h_begin, h_end); + } + RewriteExtrapolationValue(input_data, output_data, box_idx, boxes, extrapolation_value, input_shape, output_shape, + y_tops, h_begin, h_end); + return NNACL_OK; +} + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int32_t *input_shape, + const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num) { + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int c = input_shape[3]; + bool align_corners = coordinate_transform_mode == 1; + for (int batch = 0; batch < output_shape[0]; batch++) { + for (int y = tid; y < output_shape[1]; y += thread_num) { + float actual_y = calculate(y, input_shape[1], output_shape[1]); + int input_y; + if (align_corners) { + input_y = (int)(roundf(actual_y)); + } else { + input_y = (int)(floorf(actual_y)); + } + for (int x = 0; x < output_shape[2]; x++) { + float actual_x = calculate(x, input_shape[2], output_shape[2]); + int input_x; + if (align_corners) { + input_x = (int)(roundf(actual_x)); + } else { + input_x = (int)(floorf(actual_x)); + } + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float)); + } + } + } + return NNACL_OK; +} + +float CalculateAsymmetric(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized) / (float)(length_original); + return (float)(x_resized) / scale; +} + +float CalculateAlignCorners(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized - 1) / (float)(length_original - 1); + return (float)(x_resized) / scale; +} + +float CalculateHalfPixel(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized) / (float)(length_original); + float actual = (float)(x_resized + 0.5) / scale - 0.5; + return actual > 0 ? actual : 0; +} + +int CheckCropAndResizeBoxIdx(int32_t *box_idx, int32_t num_boxes, int32_t batch) { + for (int i = 0; i < num_boxes; i++) { + if (box_idx[i] < 0 || box_idx[i] >= batch) { + return NNACL_ERR; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..951b0feef4be7fabdbc614ab48f879c04a1cffc4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
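The three coordinate mappings defined above (CalculateAsymmetric, CalculateAlignCorners, CalculateHalfPixel) differ only in how an output index is pulled back onto the source grid. A sketch re-deriving the same arithmetic for a 4 -> 8 resize:

#include <stdio.h>

/* Illustration only: the three mappings evaluated for in = 4, out = 8. */
static float Asymmetric(int x, int in_len, int out_len) { return (float)x / ((float)out_len / (float)in_len); }
static float AlignCorners(int x, int in_len, int out_len) {
  return (float)x / ((float)(out_len - 1) / (float)(in_len - 1));
}
static float HalfPixel(int x, int in_len, int out_len) {
  float actual = ((float)x + 0.5f) / ((float)out_len / (float)in_len) - 0.5f;
  return actual > 0.0f ? actual : 0.0f; /* same clamp as CalculateHalfPixel */
}

int main(void) {
  for (int x = 0; x < 8; ++x) {
    printf("x=%d asymmetric=%.3f align_corners=%.3f half_pixel=%.3f\n", x, Asymmetric(x, 4, 8),
           AlignCorners(x, 4, 8), HalfPixel(x, 4, 8));
  }
  /* e.g. x=7 -> 3.500, 3.000, 3.250: align_corners pins both ends, half_pixel centers texels */
  return 0;
}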
+ */ +#ifndef MINDSPORE_NNACL_FP32_RESIZE_H_ +#define MINDSPORE_NNACL_FP32_RESIZE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include <memory.h> +#include "nnacl_c/resize_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef float (*CalculateOriginalCoordinate)(int x_resized, int length_original, int length_resized); + +int PrepareResizeBilinear(const int32_t *input_shape, const int32_t *output_shape, + CalculateOriginalCoordinate calculate, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights); + +int PrepareResizeBicubic(const int32_t *input_shape, const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int32_t *y_tops, int32_t *x_lefts, float *y_weights, float *x_weights, float cubic_coeff); + +int ResizeBilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, const int32_t *x_rights, + const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, + const int h_begin, const int h_end); + +int ResizeBicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end); + +int PrepareCropAndResizeBilinear(const int32_t *input_shape, const float *boxes, const int32_t *box_idx, + const int32_t *output_shape, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights); + +int CropAndResizeBilinear(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, + const int32_t *x_rights, const float *y_bottom_weights, const float *x_left_weights, + float *line0, float *line1, const int h_begin, const int h_end); + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int32_t *input_shape, + const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num); + +float CalculateAsymmetric(int x_resized, int length_original, int length_resized); + +float CalculateAlignCorners(int x_resized, int length_original, int length_resized); + +float CalculateHalfPixel(int x_resized, int length_original, int length_resized); + +int CheckCropAndResizeBoxIdx(int32_t *box_idx, int32_t num_boxes, int32_t batch); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_RESIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0886c07e0a2edd56bf7fbba7fa41682bbdee4971 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
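One point worth making explicit about the bilinear path declared above: line0 and line1 form a two-slot cache of horizontally interpolated source rows, and Bilinear() re-runs InterpRow only for rows not already resident. A standalone simulation of that reuse for a 4 -> 8 vertical upscale under the half-pixel mapping (sizes are made up for the illustration):

#include <stdio.h>

/* Mirrors the used-flag bookkeeping of Bilinear(), counting InterpRow calls. */
int main(void) {
  int in_h = 4, out_h = 8, rows_interpolated = 0;
  int cache[2] = {-1, -1};
  for (int h = 0; h < out_h; ++h) {
    float y = (h + 0.5f) * in_h / out_h - 0.5f;
    int bottom = y < 0.0f ? 0 : (int)y;
    int top = bottom + 1 < in_h ? bottom + 1 : in_h - 1;
    int need[2] = {bottom, top};
    int used[2] = {0, 0};
    for (int j = 0; j < 2; ++j) {
      int found = 0;
      for (int k = 0; k < 2; ++k) {
        if (need[j] == cache[k]) { used[k] = 1; found = 1; break; }
      }
      if (!found) { /* miss: interpolate into the first slot not used for this output row */
        for (int k = 0; k < 2; ++k) {
          if (!used[k]) { cache[k] = need[j]; used[k] = 1; ++rows_interpolated; break; }
        }
      }
    }
  }
  printf("InterpRow calls: %d\n", rows_interpolated); /* 4, instead of 16 without the cache */
  return 0;
}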
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/reverse_fp32.h" +#include <stddef.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/nnacl_utils.h" + +int Reverse(const float *input, float *output, size_t elem_size, int32_t *index) { + for (size_t i = 0; i < elem_size; i++) { + NNACL_ASSERT(index[i] >= 0); + output[index[i]] = input[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d95ac63339c1227ee363982014b221fbb27919c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_REVERSE_FP32_H_ +#define NNACL_FP32_REVERSE_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/reverse_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Reverse(const float *input, float *output, size_t elem_size, int32_t *index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_REVERSE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..cb93db876cd6a3a9d5cd0725f771826f090fbc97 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp32/reverse_sequence_fp32.h" + +void ReverseSequence(const float *input0, const void *input1, float *output, ReverseSequenceParameter *para) { + (void)memcpy(output, input0, para->total_data_size_); + ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_); + ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_); + for (int i = 0; i < para->outer_count_; ++i) { + const float *in = input0 + i * para->outer_stride_; + float *out = output + i * para->outer_stride_; + for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) { + const float *in_batch = in + batch * para->input_stride_[para->batch_axis_]; + float *out_batch = out + batch * para->output_stride_[para->batch_axis_]; + int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch); + NNACL_CHECK_TRUE_RET_VOID(seq_length <= para->input_shape0_[para->seq_axis_]); + for (int n = 0; n < seq_length; ++n) { + const float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_]; + float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_]; + for (int j = 0; j < para->inner_count_; ++j) { + (void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_); + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..2538f22ad606f4862ad3ecd99c5308edb1712b51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ +#define MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ + +#include <string.h> +#include "nnacl_c/common_func.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/reverse_sequence_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ReverseSequence(const float *input0, const void *input1, float *output, ReverseSequenceParameter *para); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fbdf194f00de4d5e70eba34c1f31578c9e24db10 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c @@ -0,0 +1,147 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
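ReverseSequence above first copies the whole input, then rewrites only the first seq_length entries along seq_axis for each batch element, so trailing entries keep their original order. The precomputed strides generalize the following minimal case (batch_axis = 0, seq_axis = 1, inner count 1; standalone illustration):

#include <stdio.h>
#include <string.h>

int main(void) {
  float in[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
  float out[2][4];
  int seq_len[2] = {3, 2};
  memcpy(out, in, sizeof(in)); /* entries past seq_len stay as-is */
  for (int b = 0; b < 2; ++b) {
    for (int n = 0; n < seq_len[b]; ++n) {
      out[b][n] = in[b][seq_len[b] - 1 - n]; /* reverse the leading prefix only */
    }
  }
  for (int b = 0; b < 2; ++b) {
    printf("%g %g %g %g\n", out[b][0], out[b][1], out[b][2], out[b][3]);
  }
  /* prints: 3 2 1 4
             6 5 7 8 */
  return 0;
}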
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/rmsprop_fp32.h" +#ifdef ENABLE_SSE +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#endif + +#ifdef ENABLE_AVX +#include <immintrin.h> +#endif + +#include <math.h> +#include "nnacl_c/errorcode.h" + +int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum, + float learning_rate, float decay, float epsilon, size_t start, size_t end) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + float *variable_ptr = variable + start; + float *mean_square_ptr = mean_square + start; + float *gradients_ptr = gradients + start; + float *moment_ptr = moment + start; + + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + gradient_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); + + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r); + avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1); + + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2); + _mm256_storeu_ps(moment_ptr, avx_r1); + + variable_r = _mm256_loadu_ps(variable_ptr); + variable_r = _mm256_sub_ps(variable_r, avx_r1); + _mm256_storeu_ps(variable_ptr, variable_r); + + gradients_ptr += C8NUM; + mean_square_ptr += C8NUM; + moment_ptr += C8NUM; + variable_ptr += C8NUM; + } +#endif + + for (; c1 < end; c1++) { + mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay); + moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(mean_square[c1] + epsilon); + variable[c1] -= moment[c1]; + } + return NNACL_OK; +} + +int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients, + float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + float *variable_ptr = variable + start; + float *mean_gradients_ptr = mean_gradients + start; + float *mean_square_ptr = mean_square + start; + float *moment_ptr = moment + start; + float *gradients_ptr = gradients + start; + + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r; + __m256 avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + grad_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r),
mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); + + mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr); + avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r); + mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1); + _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r); + + avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r)); + __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r); + __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS); + __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r); + + avx_r1 = _mm256_mul_ps(grad_r, lr_r); + avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r)); + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_mul_ps(moment_r, momentum_r); + avx_r1 = _mm256_add_ps(avx_r1, avx_r2); + moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r); + _mm256_storeu_ps(moment_ptr, moment_r); + + variable_r = _mm256_loadu_ps(variable_ptr); + avx_r1 = _mm256_sub_ps(variable_r, moment_r); + variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r); + _mm256_storeu_ps(variable_ptr, variable_r); + + variable_ptr += C8NUM; + mean_gradients_ptr += C8NUM; + mean_square_ptr += C8NUM; + gradients_ptr += C8NUM; + moment_ptr += C8NUM; + } +#endif + + for (; c1 < end; c1++) { + mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay); + mean_gradients[c1] += (gradients[c1] - mean_gradients[c1]) * (1.0 - decay); + float denom = (mean_square[c1] - mean_gradients[c1] * mean_gradients[c1]) + epsilon; + if (denom > 0) { + moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(denom); + variable[c1] -= moment[c1]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d0b5f1f82b0d190bb75f5b0eb60ead30854179f1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
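The scalar tails above are the clearest statement of the update rule; the AVX bodies compute eight lanes of the same arithmetic (with the centered variant additionally skipping the update wherever the variance estimate is non-positive). One non-centered step, worked with made-up numbers:

#include <math.h>
#include <stdio.h>

/* One scalar step mirroring the tail loop of RMSPropUnuseCenterFp32. */
int main(void) {
  float variable = 1.0f, mean_square = 0.0f, moment = 0.0f;
  float gradient = 0.5f, momentum = 0.9f, lr = 0.1f, decay = 0.9f, eps = 1e-7f;
  mean_square += (gradient * gradient - mean_square) * (1.0f - decay);
  moment = moment * momentum + (gradient * lr) / sqrtf(mean_square + eps);
  variable -= moment;
  printf("mean_square=%f moment=%f variable=%f\n", mean_square, moment, variable);
  /* mean_square = 0.025, moment = 0.05 / sqrt(0.025) ~= 0.3162, variable ~= 0.6838 */
  return 0;
}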
+ */ +#ifndef MINDSPORE_NNACL_RMSPROP_FP32_H +#define MINDSPORE_NNACL_RMSPROP_FP32_H + +#include <stddef.h> +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum, + float learning_rate, float decay, float epsilon, size_t start, size_t end); + +int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients, + float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RMSPROP_FP32_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6b6940f4c83e2effe15668d9ebbbba9aaacfe33d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/roi_pooling_fp32.h" +#include <math.h> +#include <float.h> +#include <string.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/op_base.h" + +int ROIPooling(const float *in_ptr, float *out_ptr, const float *roi, float *max_c, int tid, + const ROIPoolingParameter *param) { + if (param->thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + int num_rois = param->output_n_; + int units = UP_DIV(num_rois, param->thread_num_); + int roi_st = tid * units; + int roi_end = MSMIN(num_rois, roi_st + units); + if (roi_st >= num_rois) { + return NNACL_OK; + } + int batch_size = param->input_n_; + int height_ = param->input_h_; + int width_ = param->input_w_; + int channels_ = param->input_c_; + float scale = param->scale_; + int pooled_height = param->pooledH_; + int pooled_width = param->pooledW_; + const int roi_stride = 5; + int roi_ind_st = roi_st * roi_stride; + for (int i = roi_st; i < roi_end; ++i) { + int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index + if (roi_batch_ind >= batch_size) { + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + int roi_start_h = (int)roundf(roi[roi_ind_st + 1] * scale); // top-left x1 + int roi_start_w = (int)roundf(roi[roi_ind_st + 2] * scale); // top-left y1 + int roi_end_h = (int)roundf(roi[roi_ind_st + 3] * scale); // bottom-right x2 + int roi_end_w = (int)roundf(roi[roi_ind_st + 4] * scale); // bottom-right y2 + + int roi_height = MSMAX(roi_end_h - roi_start_h + 1, 1); + int roi_width = MSMAX(roi_end_w - roi_start_w + 1, 1); + + float bin_size_h = (float)roi_height / (float)pooled_height; + float bin_size_w = (float)roi_width / (float)pooled_width; + const float *batch_data = in_ptr + param->in_strides_[kNHWC_N] * roi_batch_ind; + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = (int)floorf(ph * bin_size_h); // block xi_1 + int wstart = (int)floorf(pw * bin_size_w); // block yi_1 + int
hend = (int)ceilf((ph + 1) * bin_size_h); // block xi_2 + int wend = (int)ceilf((pw + 1) * bin_size_w); // block yi_2 + + hstart = MSMIN(MSMAX(hstart + roi_start_h, 0), height_); + hend = MSMIN(MSMAX(hend + roi_start_h, 0), height_); + wstart = MSMIN(MSMAX(wstart + roi_start_w, 0), width_); + wend = MSMIN(MSMAX(wend + roi_start_w, 0), width_); + bool is_empty = (hend <= hstart) || (wend <= wstart); + for (int j = 0; j < channels_; ++j) { + max_c[j] = is_empty ? 0 : -FLT_MAX; + } + int pooled_index = i * param->out_strides_[0] + ph * param->out_strides_[1] + pw * param->out_strides_[2]; + int bd_index = hstart * param->in_strides_[1]; + for (int h = hstart; h < hend; ++h) { + int wi = bd_index + wstart * param->in_strides_[2]; + for (int w = wstart; w < wend; ++w) { + for (int c = 0; c < channels_; ++c) { + max_c[c] = MSMAX(batch_data[wi + c], max_c[c]); + } + wi += param->in_strides_[2]; + } // in_w end; + bd_index += param->in_strides_[1]; + } // in_h end + for (int j = 0; j < channels_; ++j) { + out_ptr[pooled_index + j] = max_c[j]; + } + } + } + roi_ind_st += roi_stride; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..1566f4687c459a496b1f3be6ff596bcf6adfe139 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ROI_POOLING_H_ +#define MINDSPORE_NNACL_FP32_ROI_POOLING_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ROIPoolingParameter { + // primitive parameter + OpParameter op_parameter_; + int pooledW_; + int pooledH_; + float scale_; + + // shape correlative + int in_strides_[DIMENSION_4D]; + int out_strides_[DIMENSION_4D]; + int ndim_; + int input_w_; + int input_h_; + int input_n_; + int input_c_; + int output_w_; + int output_h_; + int output_n_; + int output_c_; + + // other parameter + int thread_num_; +} ROIPoolingParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int ROIPooling(const float *in_ptr, float *out_ptr, const float *roi, float *max_c, int tid, + const ROIPoolingParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ROI_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..32fd2319f6b33309c10fcf532b667dbedfc82209 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c @@ -0,0 +1,304 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
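Each ROI row is laid out as {batch_index, x1, y1, x2, y2}; after scaling, ROIPooling max-pools every (ph, pw) bin over a window derived from floorf/ceilf of the fractional bin size, clamped to the feature map. The boundary arithmetic in isolation, for a made-up 8x8 map with a 2x2 pooled output:

#include <math.h>
#include <stdio.h>

/* Bin-boundary computation only (no data), for one already-rounded ROI. */
int main(void) {
  int height = 8, width = 8, pooled = 2;
  int roi_start_h = 1, roi_end_h = 6, roi_start_w = 2, roi_end_w = 7;
  int roi_height = roi_end_h - roi_start_h + 1;         /* 6 */
  int roi_width = roi_end_w - roi_start_w + 1;          /* 6 */
  float bin_size_h = (float)roi_height / (float)pooled; /* 3.0 */
  float bin_size_w = (float)roi_width / (float)pooled;  /* 3.0 */
  for (int ph = 0; ph < pooled; ++ph) {
    for (int pw = 0; pw < pooled; ++pw) {
      int hstart = (int)floorf(ph * bin_size_h) + roi_start_h;
      int hend = (int)ceilf((ph + 1) * bin_size_h) + roi_start_h;
      int wstart = (int)floorf(pw * bin_size_w) + roi_start_w;
      int wend = (int)ceilf((pw + 1) * bin_size_w) + roi_start_w;
      hstart = hstart < 0 ? 0 : hstart; /* same clamps as the kernel's MSMIN/MSMAX */
      hend = hend > height ? height : hend;
      wstart = wstart < 0 ? 0 : wstart;
      wend = wend > width ? width : wend;
      printf("bin(%d,%d): rows [%d,%d) cols [%d,%d)\n", ph, pw, hstart, hend, wstart, wend);
    }
  }
  return 0;
}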
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/scale_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void ScaleInner(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 result = MS_MLA256_F32(offset_8, data, scale_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + for (; in_index <= inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 result = MS_MLAQ_F32(offset_4, data, scale_4); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } +} + +void ScaleAxis(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#if defined(ENABLE_AVX) + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 result = MS_MLA256_F32(offset_8, data, scale_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 result = MS_MLAQ_F32(offset_4, data, scale_4); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index]; + } + } +} + +void DoScale(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, 
scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void ScaleInnerRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MAX256_F32(tmp, zeros_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + for (; in_index <= inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MAXQ_F32(tmp, zeros); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void ScaleAxisRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_AVX + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MAX256_F32(tmp, zeros_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MAXQ_F32(tmp, zeros); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? 
tmp : 0.0f;
+    }
+  }
+}
+
+void DoScaleRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id,
+                 const ScaleStruct *scale_param) {
+  NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_);
+  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_);
+  int outer_start = task_id * outer_step;
+  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
+
+  if (scale_param->inner_size_ == 1) {
+    ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_);
+  } else {
+    ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
+                   scale_param->inner_size_);
+  }
+}
+
+void ScaleInnerRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start,
+                     int outer_end, int axis_size, int inner_size) {
+#ifdef ENABLE_AVX
+  MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0};
+  MS_FLOAT32X8 bounds_8 = {6, 6, 6, 6, 6, 6, 6, 6};
+#endif
+#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
+  MS_FLOAT32X4 zeros = {0, 0, 0, 0};
+  MS_FLOAT32X4 bounds = {6, 6, 6, 6};
+#endif
+  for (int out = outer_start; out < outer_end; out++) {
+    int out_offset = out * axis_size * inner_size;
+    for (int i = 0; i < axis_size; i++) {
+      int axis_offset = out_offset + i * inner_size;
+      int in_index = 0;
+#if defined(ENABLE_AVX)
+      MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]);
+      MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]);
+      for (; in_index <= inner_size - C8NUM; in_index += C8NUM) {
+        int in_offset = axis_offset + in_index;
+        MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset);
+        MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8);
+        MS_FLOAT32X8 result = MS_MIN256_F32(MS_MAX256_F32(tmp, zeros_8), bounds_8);
+        MS_ST256_F32(out_data + in_offset, result);
+      }
+#endif
+#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
+      for (; in_index <= inner_size - C4NUM; in_index += C4NUM) {
+        int in_offset = axis_offset + in_index;
+        MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset);
+        MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]);
+        MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]);
+        MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4);
+        MS_FLOAT32X4 result = MS_MINQ_F32(MS_MAXQ_F32(tmp, zeros), bounds);
+        MS_STQ_F32(out_data + in_offset, result);
+      }
+#endif
+      for (; in_index < inner_size; in_index++) {
+        int in_offset = axis_offset + in_index;
+        float tmp = in_data[in_offset] * scale[i] + offset[i];
+        out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f);
+      }
+    }
+  }
+}
+
+void ScaleAxisRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start,
+                    int outer_end, int axis_size) {
+#ifdef ENABLE_AVX
+  MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0};
+  MS_FLOAT32X8 bounds_8 = {6, 6, 6, 6, 6, 6, 6, 6};
+#endif
+#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
+  MS_FLOAT32X4 zeros = {0, 0, 0, 0};
+  MS_FLOAT32X4 bounds = {6, 6, 6, 6};
+#endif
+  for (int out = outer_start; out < outer_end; out++) {
+    int out_offset = out * axis_size;
+    int index = 0;
+#ifdef ENABLE_AVX
+    for (; index <= axis_size - C8NUM; index += C8NUM) {
+      int in_offset = out_offset + index;
+      MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset);
+      MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index);
+      MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index);
+      MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8);
+      MS_FLOAT32X8 result = MS_MIN256_F32(MS_MAX256_F32(tmp, zeros_8), bounds_8);
+      MS_ST256_F32(out_data + in_offset, result);
+    }
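+    // 8-lane AVX blocks are exhausted at this point; any tail shorter than
+    // C8NUM is handled by the 4-lane SSE/NEON loop and the scalar loop below.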
+#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MINQ_F32(MS_MAXQ_F32(tmp, zeros), bounds); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..40ef1d365b63fc11c8b28d120613e6f3e8368082 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP32_SCALE_FP32_H_
+#define NNACL_FP32_SCALE_FP32_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/kernel/scale.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+void DoScale(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id,
+             const ScaleStruct *scale_param);
+void DoScaleRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id,
+                 const ScaleStruct *scale_param);
+void DoScaleRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id,
+                  const ScaleStruct *scale_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_SCALE_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..9e3d1d988b098ce6f01d33337663b762bd28f2d7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include <math.h>
+#include <float.h>
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/softmax_fp32_simd.h"
+
+void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
+  int cur_batch_offset = 0;
+  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
+    int index = 0;
+    float max = -FLT_MAX;
+
+    SIMD_RUN_NO_SCALAR(SoftmaxNormGetMax, index, src, cur_batch_offset, &max, channel);
+    for (; index < channel; index++) {
+      float input = src[cur_batch_offset + index];
+      if (input > max) {
+        max = input;
+      }
+    }
+
+    index = 0;
+    SIMD_RUN_NO_SCALAR(SoftmaxNormCalcNorm, index, src, dst, cur_batch_offset, max, channel);
+    for (; index < channel; index++) {
+      int offset = cur_batch_offset + index;
+      dst[offset] = src[offset] - max;
+    }
+  }
+}
+
+int SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) {
+  int cur_batch_offset = 0;
+  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
+    int index = 0;
+
+    // get channel's max value
+    float max = -FLT_MAX;
+    SIMD_RUN_NO_SCALAR(SoftmaxNormGetMax, index, src, cur_batch_offset, &max, channel);
+    for (; index < channel; index++) {
+      float input = src[cur_batch_offset + index];
+      if (input > max) {
+        max = input;
+      }
+    }
+
+    // get channel's exp sum value
+    float exp_sum = 0.0f;
+    index = 0;
+    SIMD_RUN_NO_SCALAR(SoftmaxLastAxisGetExpSum, index, src, dst, cur_batch_offset, max, &exp_sum, channel);
+    for (; index < channel; index++) {
+      int offset = cur_batch_offset + index;
+      float exp_out = simd_exp32_f32(src[offset] - max);
+      exp_sum += exp_out;
+      dst[offset] = exp_out;
+    }
+
+    // get result
+    NNACL_CHECK_TRUE_RET(exp_sum != 0, NNACL_ERR);
+    exp_sum = 1.0f / exp_sum;
+    index = 0;
+    SIMD_RUN_NO_SCALAR(SoftmaxLastAxisGetResult, index, dst, dst, cur_batch_offset, exp_sum, channel);
+    for (; index < channel;
index++) { + dst[cur_batch_offset + index] = dst[cur_batch_offset + index] * exp_sum; + } + } + return NNACL_OK; +} + +// output = exp(input) / reduce_sum(exp(input), axis) +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, int axis, int n_dim, + const int32_t *input_shape) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); + sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..36d5acc6da7559b2fbadb987fb3c5d98d5c040b4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/softmax_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, int axis, int n_dim, + const int32_t *input_shape); +int SoftmaxLastAxis(const float *src, float *dst, int batch, int channel); +void SoftmaxNorm(const float *src, float *dst, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..762d2c75fb5861b4e59c4414ae0261d77db2a9b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in @@ -0,0 +1,80 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t SoftmaxNormGetMax@SIMD_INSTRUCTION@(int64_t index, const float *src, int cur_batch_offset, + float *max, int channel) { + if (channel >= BLOCK_NUM * BLOCK_NUM) { + SIMD_F32 max_val = SIMD_MOV_F32(*max); + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + max_val = SIMD_MAX_F32(max_val, SIMD_LD_F32(src + cur_batch_offset + index)); + } + *max = SIMD_GET_MAX_F32(max_val); + } + return index; +} + +static inline int64_t SoftmaxNormCalcNorm@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float max, int channel) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 output = SIMD_SUB_F32(SIMD_LD_F32(src + cur_batch_offset + index), SIMD_MOV_F32(max)); + SIMD_ST_F32(dst + cur_batch_offset + index, output); + } + return index; +} + +static inline int64_t SoftmaxLastAxisGetExpSum@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float max, float *exp_sum, int channel) { +#ifndef _WIN32 + SIMD_F32 sum_val = SIMD_SET0_F32; + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); + SIMD_F32 output = SIMD_SUB_F32(input, SIMD_MOV_F32(max)); + SIMD_F32 exp_out = SIMD_EXP_F32(output); + sum_val = SIMD_ADD_F32(sum_val, exp_out); + SIMD_ST_F32(dst + cur_batch_offset + index, exp_out); + } + *exp_sum += SIMD_GET_SUM_F32(sum_val); +#endif + return index; +} + +static inline int64_t SoftmaxLastAxisGetResult@SIMD_INSTRUCTION@(int64_t index, const float 
*src, float *dst, + int cur_batch_offset, float exp_sum, int channel) { + SIMD_F32 exp_sum_val = SIMD_MOV_F32(exp_sum); + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); + SIMD_F32 output = SIMD_MUL_F32(input, exp_sum_val); + SIMD_ST_F32(dst + cur_batch_offset + index, output); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a9cdf4a03ac459b8f8846d2f12e36a360691388a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/softmax_grad_fusion_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/softmax_grad_fusion_fp32_simd.h" +#include "nnacl_c/sub_fp32_simd.h" + +void SoftmaxGradFusionOpt(const float *a, const float *b, float *dst, int64_t m) { + float result = 0; + + int64_t i = 0; + SIMD_RUN_NO_SCALAR(SoftmaxGradFusionOpt, i, a, b, &result, m); + for (; i < m; i++) { + result += a[i] * b[i]; + } + + i = 0; + SIMD_RUN_NO_SCALAR(ElementOptSubMul, i, a, b, result, dst, m); + for (; i < m; i++) { + dst[i] = a[i] * (b[i] - result); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..5c69f9fd0d7ba89f098c3547d14b118d53abaca0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SoftmaxGradFusionOpt(const float *a, const float *b, float *dst, int64_t m); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..01a1de642193a2fb7eb48d740a5d861c848bbfda --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t SoftmaxGradFusionOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + result_vec = SIMD_FMADD_F32(a_vec, b_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +static inline int64_t ElementOptSubMul@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float sum, + float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(sum); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_ST_F32(out + index, SIMD_MUL_F32(vin0, SIMD_SUB_F32(vin1, vin1_opt_))); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b2b026b12d67cf4d09944961e30678b05863ac67 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/space_to_batch_fp32.h" +#include "nnacl_c/errorcode.h" + +int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + const int input_batch = param->input_shape_[0]; + const int input_height = param->input_shape_[1]; + const int input_width = param->input_shape_[2]; + + const int output_batch = param->output_shape_[0]; + const int output_height = param->output_shape_[1]; + const int output_width = param->output_shape_[2]; + + const int block_shape_height = param->block_sizes_[0]; + const int block_shape_width = param->block_sizes_[1]; + const int padding_top = param->paddings_[0]; + const int padding_left = param->paddings_[2]; + + NNACL_CHECK_ZERO_RETURN_ERR(input_batch); + NNACL_CHECK_ZERO_RETURN_ERR(block_shape_width); + int copy_size = param->input_shape_[3] * param->data_type_len; + for (int64_t out_b = task_id; out_b < output_batch; out_b += param->op_parameter_.thread_num_) { + int in_b = out_b % input_batch; + int shift_w = (out_b / input_batch) % block_shape_width; + int shift_h = (out_b / input_batch) / block_shape_width; + for (int out_h = 0; out_h < output_height; out_h++) { + for (int out_w = 0; out_w < output_width; out_w++) { + int64_t output_offset = + out_b * param->out_stride_[0] + out_h * param->out_stride_[1] + out_w * param->out_stride_[2]; + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { + memset((int8_t *)output + output_offset * param->data_type_len, 0, copy_size); + } else { + int in_h = (out_h * block_shape_height + shift_h) - padding_top; + int in_w = (out_w * block_shape_width + shift_w) - padding_left; + int input_offset = in_b * param->in_stride_[0] + in_h * param->in_stride_[1] + in_w * param->in_stride_[2]; + memcpy((int8_t *)output + output_offset * param->data_type_len, + (const int8_t *)input + input_offset * param->data_type_len, copy_size); + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..89a1abc51ee167d2c168b20e17833d2c2eda2212 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_
+#define MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+
+typedef struct SpaceToBatchParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int block_sizes_[4];
+  int paddings_[4];
+
+  // shape correlative
+  int input_shape_[4];
+  int output_shape_[4];
+  int in_stride_[4];
+  int out_stride_[4];
+  int padded_in_shape_[4];
+
+  // other parameter
+  bool need_paddings_;
+  int m_;
+  int data_type_len;
+} SpaceToBatchParameter;
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..c1ef67ccbf172a3b7b7f9d576a732fbd9c92b291
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/errorcode.h" + +int SparseToDenseSetDefault(float *output, float default_value, const SparseToDenseParameter *param, int task_id) { + if (output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->output_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->output_num); + for (int i = begin; i < end; i++) { + output[i] = default_value; + } + return NNACL_OK; +} + +int SparseToDense(int32_t *indices_vec, const float *sparse_values, float default_value, float *output, + SparseToDenseParameter *param, int task_id) { + if (indices_vec == NULL || sparse_values == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->index_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->index_num); + + int stride0 = param->output_stride[0]; + int stride1 = param->output_stride[1]; + int stride2 = param->output_stride[2]; + + if (param->validate_indices_ == true) { + int index_before = -1; + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + if (index <= index_before) { + return NNACL_ERR; + } + index_before = index; + } + } + + if (param->is_scalar == true) { + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[0]; + } + } else { + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[i]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..874a2cf9d531e232dbcd0204e165606c347c98d0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ +#define MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ + +#include "nnacl_c/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SparseToDenseSetDefault(float *output, float default_value, const SparseToDenseParameter *param, int task_id); +int SparseToDense(int32_t *indices_vec, const float *sparse_values, float default_value, float *output, + SparseToDenseParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0336d6dacdd3d1df555891c90a15227b708859b5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/splice_fp32.h" +void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, + float *dst_data, int dst_row, int dst_col) { + int forward_index = 0; + for (int r = 0; r < dst_row; ++r) { + float *dst_row_data = dst_data + r * dst_col; + for (int off = 0; off < splice_parameter->context_dim_; ++off) { + int r_off = splice_parameter->forward_indexes_[forward_index]; + forward_index++; + const float *tmp_src_data = src_data + r_off * src_col; + float *tmp_dst_data = dst_row_data + off * src_col; + memcpy(tmp_dst_data, tmp_src_data, (size_t)(src_col) * sizeof(float)); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..83c937c82095c60442d21be89c2c70b4eddd198b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP32_SPLICE_FP32_H_
+#define NNACL_FP32_SPLICE_FP32_H_
+
+#include <string.h>
+#include "nnacl_c/splice_parameter.h"
+
+void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter,
+                float *dst_data, int dst_row, int dst_col);
+
+#endif  // NNACL_FP32_SPLICE_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c
new file mode 100644
index 0000000000000000000000000000000000000000..5dc4d39080d35b2734d03e293225efca8b9819c9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/squared_difference.h"
+#include "nnacl_c/fp32/sub_fp32.h"
+#include "nnacl_c/fp32/mul_fp32.h"
+
+int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size) {
+  ElementSub(in0, in1, out, size);
+  return ElementMul(out, out, out, size);
+}
+
+int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size, bool scale) {
+  ElementOptSub(in0, in1, out, size, scale);
+  return ElementMul(out, out, out, size);
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h
new file mode 100644
index 0000000000000000000000000000000000000000..2d4db9dee8492288acbbc300c11fd230c55a0be5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_
+#define MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/arithmetic_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Element Squared Difference */
+int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size);
+int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size, bool scale);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..9fda1af46c737d5248f52307bccd0e64777d75bf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/strided_slice_fp32.h"
+#include "nnacl_c/nnacl_common.h"
+#include "nnacl_c/errorcode.h"
+
+int PadStridedSliceParameterTo8D(StridedSliceStruct *strided_slice) {
+  if (strided_slice->in_shape_size_ > DIMENSION_8D) {
+    return NNACL_STRIDED_SLICE_UNSUPPORTED_MAX_8D;
+  }
+
+  int32_t begins[DIMENSION_8D];
+  int32_t ends[DIMENSION_8D];
+  int32_t strides[DIMENSION_8D];
+  int32_t input_shape[DIMENSION_8D];
+  int32_t i;
+  for (i = 0; i < strided_slice->in_shape_size_; ++i) {
+    begins[i] = strided_slice->begins_[i];
+    ends[i] = MSMIN(strided_slice->ends_[i], strided_slice->in_shape_[i]);
+    strides[i] = strided_slice->strides_[i];
+    input_shape[i] = strided_slice->in_shape_[i];
+  }
+
+  int32_t real_index = strided_slice->in_shape_size_ - 1;
+  for (i = DIMENSION_8D - 1; i >= 0; --i) {
+    if (real_index >= 0) {
+      strided_slice->begins_[i] = begins[real_index];
+      strided_slice->ends_[i] = ends[real_index];
+      strided_slice->strides_[i] = strides[real_index];
+      strided_slice->in_shape_[i] = input_shape[real_index--];
+    } else {
+      strided_slice->begins_[i] = 0;
+      strided_slice->ends_[i] = 1;
+      strided_slice->strides_[i] = 1;
+      strided_slice->in_shape_[i] = 1;
+    }
+  }
+  strided_slice->in_shape_size_ = DIMENSION_8D;
+  return NNACL_OK;
+}
+
+bool LoopContinue(int stride, int i, int end) { return stride > 0 ?
i < end : i > end; } + +int DoStridedSliceIn8D(const void *input, void *output, StridedSliceStruct *strided_slice) { + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + const uint8_t *in = (const uint8_t *)input; + uint8_t *out = (uint8_t *)output; + int data_type_size = (int)DataTypeCSize(strided_slice->data_type_); + + int32_t *begins = strided_slice->begins_; + int32_t *ends = strided_slice->ends_; + int32_t *strides = strided_slice->strides_; + int32_t *in_shape = strided_slice->in_shape_; + + int dim_offset[DIMENSION_8D - 1]; + dim_offset[6] = in_shape[7]; + dim_offset[5] = in_shape[6] * dim_offset[6]; + dim_offset[4] = in_shape[5] * dim_offset[5]; + dim_offset[3] = in_shape[4] * dim_offset[4]; + dim_offset[2] = in_shape[3] * dim_offset[3]; + dim_offset[1] = in_shape[2] * dim_offset[2]; + dim_offset[0] = in_shape[1] * dim_offset[1]; + size_t out_offset = 0; + int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; + for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { + for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { + for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { + for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { + for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { + for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { + for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { + for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { + int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + + dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + + dim6 * dim_offset[6] + dim7; + memcpy(out + out_offset * data_type_size, in + in_offset * data_type_size, data_type_size); + out_offset++; + } + } + } + } + } + } + } + } + return NNACL_OK; +} + +void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size, + size_t in_offset) { + if (stride == 1) { + size_t unit = split_len * inner_size; + for (size_t i = 0; i < outer; ++i) { + memcpy(output, input, unit); + output += unit; + input += in_offset; + } + return; + } + for (size_t i = 0; i < outer; ++i) { + const uint8_t *input_ptr = input + i * in_offset; + for (int j = 0; j < split_len; ++j) { + memcpy(output, input_ptr, inner_size); + output += inner_size; + input_ptr += inner_size * stride; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..57ab5ce80d7720da8335f0ee096388c89d239207 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP32_STRIDED_SLICE_FP32_H_
+#define NNACL_FP32_STRIDED_SLICE_FP32_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/strided_slice_parameter.h"
+#include "nnacl_c/kernel/strided_slice.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int PadStridedSliceParameterTo8D(StridedSliceStruct *strided_slice);
+int DoStridedSliceIn8D(const void *input, void *output, StridedSliceStruct *strided_slice);
+
+void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size,
+                size_t in_offset);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_FP32_STRIDED_SLICE_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..6975d41caa0b3d5441f8572bc1555f5987e04f86
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c
@@ -0,0 +1,150 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/sub_fp32.h"
+#include "nnacl_c/sub_fp32_simd.h"
+#include "nnacl_c/errorcode.h"
+
+int ElementOptSub(const float *in0, const float *in1, float *out, int size, bool first_scalar) {
+  int index = 0;
+  if (first_scalar) {
+    SIMD_RUN_NO_SCALAR(ElementOptSubNum0, index, in0, in1, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[0] - in1[index];
+    }
+  } else {
+    SIMD_RUN_NO_SCALAR(ElementOptSubNum1, index, in0, in1, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[index] - in1[0];
+    }
+  }
+  return NNACL_OK;
+}
+
+int ElementOptSubExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar) {
+  int index = 0;
+  if (first_scalar) {
+    SIMD_RUN_NO_SCALAR(ElementOptSubExtNum0, index, in0, in1, alpha, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[0] - in1[index] * alpha;
+    }
+  } else {
+    SIMD_RUN_NO_SCALAR(ElementOptSubExtNum1, index, in0, in1, alpha, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[index] - in1[0] * alpha;
+    }
+  }
+  return NNACL_OK;
+}
+
+int ElementSubExt(const float *in0, const float *in1, const float alpha, float *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementSubExt, index, in0, in1, alpha, out, size);
+  for (; index < size; index++) {
+    out[index] = in0[index] - in1[index] * alpha;
+  }
+  return NNACL_OK;
+}
+
+int ElementOptSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) {
+  int index = 0;
+  if (first_scalar) {
+    SIMD_RUN_NO_SCALAR(ElementOptSubIntNum0, index, in0, in1, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[0] - in1[index];
+    }
+  } else {
+    SIMD_RUN_NO_SCALAR(ElementOptSubIntNum1, index, in0, in1, out, size);
+    for (; index < size; index++) {
+      out[index] = in0[index] -
in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubReluNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] - in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubReluNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] - in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubRelu6Num0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] - in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubRelu6Num1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] - in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementSub(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSub, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index]; + } + return NNACL_OK; +} + +int ElementSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index]; + } + return NNACL_OK; +} + +int ElementSubRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] - in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] - in1[index], 0), 6); + } + + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..66063fbf6993fe637a2ae9dc5b21b89673f4f027 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_SUB_FP32_H_
+#define MINDSPORE_NNACL_SUB_FP32_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/arithmetic_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ElementSub(const float *in0, const float *in1, float *out, int size);
+int ElementSubExt(const float *in0, const float *in1, const float alpha, float *out, int size);
+int ElementOptSubExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar);
+int ElementSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size);
+int ElementSubRelu(const float *in0, const float *in1, float *out, int size);
+int ElementSubRelu6(const float *in0, const float *in1, float *out, int size);
+int ElementOptSub(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar);
+int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_SUB_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..36bc85e30287a6d53e94d231647cb01a2178e586
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_SUB_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_SUB_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int ElementOptSubNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                      int size) {
+  SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_SUB_F32(vin0_opt, vin1);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                      int size) {
+  SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_opt_);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+
+static inline int ElementOptSubExtNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out,
+                                                         int size) {
+  SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]);
+  SIMD_F32 valpha = SIMD_MOV_F32(alpha);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha);
+    SIMD_F32 vout = SIMD_SUB_F32(vin0_opt, vin1_alpha);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubExtNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out,
+                                                         int size) {
+  SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]);
+  SIMD_F32 valpha = SIMD_MOV_F32(alpha);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1_opt_, valpha);
+    SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_alpha);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  SIMD_EPI32 vin0_opt = SIMD_MOV_EPI32(in0[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0_opt, vin1);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0, vin1_opt_);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                          int size) {
+  SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0_opt, vin1), 0.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
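+  // Return the first index the vector loop did not cover; the scalar tail
+  // loop in the ElementOptSubRelu caller finishes the remaining elements.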
+  return index;
+}
+
+static inline int ElementOptSubReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                          int size) {
+  SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1_opt_), 0.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                           int size) {
+  SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0_opt, vin1), 0.0f), 6.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptSubRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                           int size) {
+  SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1_opt_), 0.0f), 6.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementSub@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementSubExt@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) {
+  // broadcast alpha once; it does not change across iterations
+  SIMD_F32 valpha = SIMD_MOV_F32(alpha);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha);
+    SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_alpha);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementSubInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0, vin1);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementSubRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                   int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1), 0.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementSubRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out,
+                                                    int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout =
SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ad268d8c350fba316297e31c3ba9f696901cdc1a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/topk_fp32.h" +#include "nnacl_c/errorcode.h" + +int DescendCmp(const void *a, const void *b) { + NNACL_CHECK_NULL_RETURN_ERR(a); + NNACL_CHECK_NULL_RETURN_ERR(b); + float sub = ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; + if (sub > 0) { + return 1; + } else if (sub < 0) { + return -1; + } + if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) { + return 1; + } else { + return -1; + } +} + +int IndexSortCmp(const void *a, const void *b) { + if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) { + return 1; + } else { + return -1; + } +} + +void Topk(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + float *cur_input_data = (float *)input_data; + float *cur_output_data = (float *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = *(cur_input_data + offset); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = top_map[m].element; + cur_output_index[offset] = top_map[m].index; + } + } + } +} + +void TopkInt(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + int32_t *cur_input_data = (int32_t *)input_data; + int32_t *cur_output_data = (int32_t *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * 
inner_loop_num;
+    for (int j = 0; j < inner_loop_num; j++) {
+      for (int m = 0; m < dim_size; m++) {
+        int offset = in_offset + m * inner_loop_num + j;
+        // NOTE: int32 values are staged as float so both paths share DescendCmp;
+        // magnitudes above 2^24 may lose precision in the round trip
+        top_map[m].element = (float)(*(cur_input_data + offset));
+        top_map[m].index = m;
+      }
+      qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmp);
+      if (!parameter->sorted_) {
+        qsort(top_map, k, sizeof(top_map[0]), IndexSortCmp);
+      }
+      for (int m = 0; m < k; m++) {
+        int offset = out_offset + m * inner_loop_num + j;
+        cur_output_data[offset] = (int)(top_map[m].element);
+        cur_output_index[offset] = top_map[m].index;
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..8adfd5bebc7510d1cabd97bd7b907c627116bb75
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_TOPK_H_
+#define MINDSPORE_NNACL_TOPK_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct TopkNode {
+  float element;
+  int32_t index;
+} TopkNode;
+
+typedef struct TopkParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int k_;
+  int axis_;
+  bool sorted_;
+
+  // other parameter
+  int dim_size_;
+  int outer_loop_num_;
+  int inner_loop_num_;
+  void *topk_node_list_;
+} TopkParameter;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Topk(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter);
+void TopkInt(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_TOPK_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..550df6dd69ab03d98379bc7077f8096beacea8d5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c
@@ -0,0 +1,248 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/op_base.h" + +void TransposeDim2Fp32(const float *in_data, float *out_data, const int32_t *strides, int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void TransposeDim3Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + 
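+      // innermost three axes: the output is written contiguously while the input is gathered through the permuted strides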
+      for (int k = 0; k < output2; ++k) {
+        int out_stride2_k = k * out_stride2;
+        int stride2_k = k * stride2;
+        for (int m = 0; m < output3; ++m) {
+          int out_stride3_m = m * out_stride3;
+          int stride3_m = m * stride3;
+          for (int n = 0; n < output4; ++n) {
+            out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] =
+              in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4];
+          }
+        }
+      }
+    }
+  }
+}
+
+void TransposeDim6Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides,
+                       const int32_t *perm, const int32_t *output_shape) {
+  const int stride0 = strides[perm[0]];
+  const int stride1 = strides[perm[1]];
+  const int stride2 = strides[perm[2]];
+  const int stride3 = strides[perm[3]];
+  const int stride4 = strides[perm[4]];
+  const int stride5 = strides[perm[5]];
+  const int out_stride0 = out_strides[0];
+  const int out_stride1 = out_strides[1];
+  const int out_stride2 = out_strides[2];
+  const int out_stride3 = out_strides[3];
+  const int out_stride4 = out_strides[4];
+  const int output0 = output_shape[0];
+  const int output1 = output_shape[1];
+  const int output2 = output_shape[2];
+  const int output3 = output_shape[3];
+  const int output4 = output_shape[4];
+  const int output5 = output_shape[5];
+
+  for (int i = 0; i < output0; ++i) {
+    int out_stride0_i = i * out_stride0;
+    int stride0_i = i * stride0;
+    for (int j = 0; j < output1; ++j) {
+      int out_stride1_j = j * out_stride1;
+      int stride1_j = j * stride1;
+      for (int k = 0; k < output2; ++k) {
+        int out_stride2_k = k * out_stride2;
+        int stride2_k = k * stride2;
+        for (int m = 0; m < output3; ++m) {
+          int out_stride3_m = m * out_stride3;
+          int stride3_m = m * stride3;
+          for (int n = 0; n < output4; ++n) {
+            int out_stride4_n = n * out_stride4;
+            int stride4_n = n * stride4;
+            for (int g = 0; g < output5; ++g) {
+              out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] =
+                in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5];
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+void TransposeDimsFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides,
+                       int32_t *out_strides, int num_axes, int task_id, int thread_num) {
+  const float *in_data = (const float *)in;
+  float *out_data = (float *)out;
+  NNACL_CHECK_NULL_RETURN_VOID(in_data);
+  NNACL_CHECK_NULL_RETURN_VOID(out_data);
+  NNACL_CHECK_NULL_RETURN_VOID(output_shape);
+
+  int data_size = (*out_strides) * output_shape[0];
+  int offset_size = UP_DIV(data_size, thread_num);
+  int task_offset = offset_size * task_id;
+  int count = data_size - task_offset;
+  if (count <= 0) {
+    return;
+  }
+  count = MSMIN(offset_size, count);
+  for (int idx = task_offset; idx < task_offset + count; ++idx) {
+    int pos = idx;
+    int output_idx = 0;
+    int input_idx = 0;
+    for (int i = 0; i < num_axes; ++i) {
+      NNACL_CHECK_ZERO_RETURN(*(out_strides + i));
+      int position = pos / *(out_strides + i);
+      int out_stride = i < num_axes - 1 ?
out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int data_size, int num_axes) { + const float *in_data = (const float *)in; + float *out_data = (float *)out; + + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; ++i) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + for (int i = 0; i < num_axes; ++i) { + if (perm[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (num_axes == 2) { + TransposeDim2Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + TransposeDim3Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + TransposeDim4Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + TransposeDim5Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 6) { + TransposeDim6Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3cca6000b83d963c72d685664477a6c0c4ba8ab2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_TRANSPOSE_H_
+#define MINDSPORE_NNACL_FP32_TRANSPOSE_H_
+
+#include <string.h>
+#include <stdint.h>
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int DoTransposeFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides,
+                    int32_t *out_strides, int data_size, int num_axes);
+void TransposeDimsFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides,
+                       int32_t *out_strides, int num_axes, int task_id, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_TRANSPOSE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..44bc0de2ce770246c19753d751d3ed9e7b92f676
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c
@@ -0,0 +1,239 @@
+#ifdef BFC_MEMORY
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/transpose_server_fp32.h"
+
+#define JUDGEPART(NUM)                           \
+  if (dim_start##NUM == overflow_point##NUM) {   \
+    dim_start##NUM = 0;                          \
+  } else {                                       \
+    ++dim_start##NUM;                            \
+    in_offset += stride##NUM;                    \
+    continue;                                    \
+  }
+
+void DoTransposeServerDim3(const float *in_data, float *out_data, const int64_t *overflow_points,
+                           const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) {
+  int64_t stride2 = strides[THIRD_INPUT];
+  int64_t size = boundary_info->sizes[0];
+  int64_t in_offset = boundary_info->in_offsets[0];
+  out_data += boundary_info->out_start_offset;
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = in_data[in_offset + i * stride2];
+  }
+  int64_t dim_start1 = boundary_info->start_dim[1];
+  int64_t overflow_point1 = overflow_points[1];
+  int64_t overflow_point2 = overflow_points[THIRD_INPUT];
+  int64_t stride0 = strides[0];
+  int64_t stride1 = strides[1];
+  int64_t last_dim = overflow_point2 + 1;
+  out_data += size;
+  size = boundary_info->sizes[1];
+  in_offset = boundary_info->in_offsets[1];
+  for (int64_t i = 0; i < size; i += last_dim) {
+    for (int64_t j = 0; j < overflow_point2; ++j) {
+      out_data[i + j] = in_data[in_offset];
+      in_offset += stride2;
+    }
+    out_data[i + overflow_point2] = in_data[in_offset];
+    JUDGEPART(1)
+    in_offset += stride0;
+  }
+  out_data += size;
+  size = boundary_info->sizes[THIRD_INPUT];
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = in_data[in_offset + i * stride2];
+  }
+}
+
+void DoTransposeServerDim4(const float *in_data, float *out_data, const int64_t *overflow_points,
+                           const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) {
+  int64_t stride3 = strides[FOURTH_INPUT];
+  int64_t size = boundary_info->sizes[0];
+  int64_t in_offset = boundary_info->in_offsets[0];
+  out_data += boundary_info->out_start_offset;
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = in_data[in_offset +
i * stride3]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t last_dim = overflow_point3 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point3; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride3; + } + out_data[i + overflow_point3] = in_data[in_offset]; + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride3]; + } +} + +void DoTransposeServerDim5(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride4 = strides[FIFTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride4]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t dim_start3 = boundary_info->start_dim[FOURTH_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t overflow_point4 = overflow_points[FIFTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t last_dim = overflow_point4 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point4; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride4; + } + out_data[i + overflow_point4] = in_data[in_offset]; + JUDGEPART(3) + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride4]; + } +} + +void DoTransposeServerDim6(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride5 = strides[SIXTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride5]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t dim_start3 = boundary_info->start_dim[FOURTH_INPUT]; + int64_t dim_start4 = boundary_info->start_dim[FIFTH_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t overflow_point4 = overflow_points[FIFTH_INPUT]; + int64_t overflow_point5 = overflow_points[SIXTH_INPUT]; + 
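+  // the locals below feed JUDGEPART(N), which steps axis N of the input counter and carries into the next outer axis when it overflows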
int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t stride4 = strides[FIFTH_INPUT]; + int64_t last_dim = overflow_point5 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point5; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride5; + } + out_data[i + overflow_point5] = in_data[in_offset]; + JUDGEPART(4) + JUDGEPART(3) + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride5]; + } +} + +void DoTransposeServer(const float *in_data, float *out_data, const int64_t *overflow_points, const int64_t *strides, + int axis_num, const TransposeBlockBoundaryInfo *boundary_info) { + if (axis_num == DIMENSION_3D) { + DoTransposeServerDim3(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_4D) { + DoTransposeServerDim4(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_5D) { + DoTransposeServerDim5(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_6D) { + DoTransposeServerDim6(in_data, out_data, overflow_points, strides, boundary_info); + return; + } + out_data += boundary_info->out_start_offset; + int64_t stride = strides[axis_num - 1]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride]; + } + int64_t dim_info[MAX_TRANSPOSE_DIM_SIZE] = {}; + for (int i = 0; i < axis_num; ++i) { + dim_info[i] = boundary_info->start_dim[i]; + } + int64_t last_overflow_point = overflow_points[axis_num - 1]; + int64_t last_dim = last_overflow_point + 1; + out_data += size; + size = boundary_info->sizes[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < last_overflow_point; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride; + } + out_data[i + last_overflow_point] = in_data[in_offset]; + int j = axis_num - 2; + while (dim_info[j] == overflow_points[j]) { + dim_info[j] = 0; + --j; + } + ++dim_info[j]; + in_offset += strides[j]; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride]; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..1c1be31baef574290c1ddb72165679f366e4e68c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ +#define MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ + +#ifdef BFC_MEMORY +#include "nnacl_c/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TransposeBlockBoundaryInfo { + int64_t out_start_offset; + int64_t sizes[C3NUM]; + int64_t in_offsets[C2NUM]; + int64_t start_dim[MAX_TRANSPOSE_DIM_SIZE]; +} TransposeBlockBoundaryInfo; + +void DoTransposeServer(const float *in_data, float *out_data, const int64_t *overflow_points, const int64_t *strides, + int axis_num, const TransposeBlockBoundaryInfo *boundary_info); +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..15de2a44646f0b0033bd6839be6658284be78069 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c @@ -0,0 +1,179 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/triu_tril_fp32.h" + +int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width) { + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + for (size_t i = 0; i < input_tensor->shape_size_; i++) { + if (input_tensor->shape_[i] <= 0) { + return NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR; + } + } + + size_t input_hw_dims = Num2; + NNACL_CHECK_FALSE(input_tensor->shape_size_ < DIMENSION_2D, NNACL_TRIU_INPUT_DIMS_INVALID); + + *mul = 1; + for (size_t i = 0; i < input_tensor->shape_size_ - input_hw_dims; i++) { + *mul *= input_tensor->shape_[i]; + } + *height = input_tensor->shape_[input_tensor->shape_size_ - Num2]; + *width = input_tensor->shape_[input_tensor->shape_size_ - Num1]; + + return NNACL_OK; +} + +int TriuTrilGetKValue(KernelBase *self, int64_t *k) { + if (self->in_size_ <= 1) { + *k = 0; + return NNACL_OK; + } + + TensorC *k_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(k_tensor); + NNACL_CHECK_NULL_RETURN_ERR(k_tensor->data_); + + switch (k_tensor->data_type_) { + case kNumberTypeInt: + case kNumberTypeInt32: + *k = *((int32_t *)k_tensor->data_); + break; + case kNumberTypeInt64: + *k = *((int64_t *)k_tensor->data_); + break; + default: + return NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int64_t *src_data = (const int64_t *)src; + int64_t *dst_data = (int64_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int32_t *src_data = (const int32_t *)src; + int32_t *dst_data = (int32_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int16_t *src_data = (const int16_t *)src; + int16_t *dst_data = (int16_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} +void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int8_t *src_data = (const int8_t *)src; + int8_t *dst_data = (int8_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? 
src_data[index] : 0; + } + } + } +} + +void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int64_t *src_data = (const int64_t *)src; + int64_t *dst_data = (int64_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} + +void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int32_t *src_data = (const int32_t *)src; + int32_t *dst_data = (int32_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} +void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int16_t *src_data = (const int16_t *)src; + int16_t *dst_data = (int16_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} +void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int8_t *src_data = (const int8_t *)src; + int8_t *dst_data = (int8_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..24205877e176242f60729cca54d3e79944089716 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ +#define MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width); +int TriuTrilGetKValue(KernelBase *self, int64_t *k); + +void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); + +void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d0716dd13755a63efc60a0ee0954ad45ae298829 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/unique_fp32.h" + +int Find(const float *array, int len, float target) { + if (array == NULL) { + return -1; + } + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void Unique(const float *input, int input_len, float *output0, int32_t *output0_len, int32_t *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = Find(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} + +int FindInt(const int32_t *array, int len, int target) { + if (array == NULL) { + return -1; + } + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void UniqueInt(const int32_t *input, int input_len, int32_t *output0, int32_t *output0_len, int32_t *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = FindInt(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3b503922579d503ae5260b2aa3b48c61bb174335 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_UNIQUE_H +#define MINDSPORE_NNACL_UNIQUE_H + +#include "nnacl_c/op_base.h" + +typedef struct UniqueParameter { + // primitive parameter + OpParameter op_parameter_; +} UniqueParameter; + +#ifdef __cplusplus +extern "C" { +#endif +void Unique(const float *input, int input_len, float *output0, int32_t *output0_len, int32_t *output1); +void UniqueInt(const int32_t *input, int input_len, int32_t *output0, int32_t *output0_len, int32_t *output1); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_UNIQUE_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..efcc3530b8e7946172b94e48f17346a053203158 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/where_fp32.h"
+#include "nnacl_c/common_func.h"
+
+void WhereWithTripleInputs(const float *x, const float *y, float *output, const WhereArgs *param, int task_id,
+                           int thread_num) {
+  const bool *condition = param->condition_;
+  int stride = UP_DIV(param->max_num_, thread_num);
+  int begin = task_id * stride;
+  int end = MSMIN(begin + stride, param->max_num_);
+
+  for (int i = begin; i < end; ++i) {
+    bool cond = condition[param->condition_num_ > 1 ? i : 0];
+    if (cond) {
+      output[i] = x[param->x_num_ > 1 ? i : 0];
+    } else {
+      output[i] = y[param->y_num_ > 1 ? i : 0];
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..d1112c0a149663dcb348357711d8e3857c0f446b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_WHERE_FP32_H_
+#define MINDSPORE_NNACL_FP32_WHERE_FP32_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/where_parameter.h"
+#include "nnacl_c/kernel/where.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void WhereWithTripleInputs(const float *x, const float *y, float *output, const WhereArgs *param, int task_id,
+                           int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_WHERE_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c
new file mode 100644
index 0000000000000000000000000000000000000000..e3be43d9bbef0ea38c83a5ef86fe542736a7dc7c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c
@@ -0,0 +1,2233 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/fp32/winograd_avx.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[16]; + LoadAvx16Data; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_SUB256_F32(src[offset], src[2 + offset]); + t[4 + l] = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[8 + l] = MS_SUB256_F32(src[2 + offset], src[1 + offset]); + t[12 + l] = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = MS_SUB256_F32(t[offset], t[2 + offset]); + m[4 + l] = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[8 + l] = MS_SUB256_F32(t[2 + offset], t[1 + offset]); + m[12 + l] = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[36]; + MS_FLOAT32X8 m[36]; + LoadAvx36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(src[4 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 4), MS_MUL256_N_F32(src[2 + offset], 5)), + src[4 + offset]); + t[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADD256_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUB256_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + t[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[1 + offset], 4), MS_MUL256_N_F32(src[3 + offset], 5)), + src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(t[4 + offset], t[2 + offset]); + m[l] = + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 4), MS_MUL256_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADD256_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUB256_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + m[24 + l] = 
MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + m[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[1 + offset], 4), MS_MUL256_N_F32(t[3 + offset], 5)), + t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform8x8AvxUnit_block8(const float *src_data, float *dst_data, const int src_step, const int dst_step) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[64]; + MS_FLOAT32X8 m[64]; + LoadAvx64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 0.5625), MS_MUL256_N_F32(src[2 + offset], 3.0625)), + MS_MUL256_N_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 1.125), MS_MUL256_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 2.25), MS_MUL256_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.5625), MS_MUL256_N_F32(src[4 + offset], 2.5)); + t[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.375), MS_MUL256_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.25), MS_MUL256_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], -0.5625), 
MS_MUL256_N_F32(src[3 + offset], 3.0625)), + MS_MUL256_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 0.5625), MS_MUL256_N_F32(t[2 + offset], 3.0625)), + MS_MUL256_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 1.125), MS_MUL256_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 2.25), MS_MUL256_N_F32(t[4 + offset], 3.25)); + m[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.5625), MS_MUL256_N_F32(t[4 + offset], 2.5)); + m[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.375), MS_MUL256_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.25), MS_MUL256_N_F32(t[4 + offset], 1.25)); + m[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], -0.5625), MS_MUL256_N_F32(t[3 + offset], 3.0625)), + MS_MUL256_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } +} + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + if (real_c == C8NUM) { + InputTransform8x8AvxUnit_block8(src_data, dst_data, src_step, dst_step); + } else { + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * 
t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 
+ int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + m[l + 2] = MS_MIN256_F32(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[9]; + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(src[offset], tmp); + t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(src[offset], tmp); + t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = 
MS_MAX256_F32(zero, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(src[offset], tmp); + t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 3] = MS_MIN256_F32(six, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[4]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void 
OutputTransform6x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + m[l + 2] = MS_MIN256_F32(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} +void OutputTransform6x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[18]; + MS_FLOAT32X8 m[9]; + LoadAvx36Data; 
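+  /* Output transform A^T * M * A: a 6x6 tile of 8-float vectors collapses to a
+   * 3x3 output block. LoadAvx36Data presumably fills src[0..35] from src_data at
+   * src_step intervals (macro defined elsewhere in this file). The two loops below
+   * evaluate the Winograd interpolation nodes {0, +-1, +-2, inf}; a hypothetical
+   * scalar sketch of the row transform:
+   *   m0 = t0 + (t1 + t2) + (t3 + t4) + bias;      // node powers p^0
+   *   m1 = (t1 - t2) + 2 * (t3 - t4) + bias;       // node powers p^1
+   *   m2 = (t1 + t2) + 4 * (t3 + t4) + t5 + bias;  // p^2 plus the inf-node term
+   */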
+ MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[18]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 
t[18]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 3] = MS_MIN256_F32(six, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[16]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; 
i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); 
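+    /* tmp1/tmp2 pair the even-node sums and tmp3/tmp4 (declared next) the odd-node
+     * differences; row k then scales the node-2 term by 2^k, e.g. in hypothetical
+     * scalar form m3 = (t1 - t2) + 8 * (t3 - t4) + t5 + bias, after which each row
+     * is clamped to [0, 6] by the MS_MAX256_F32 / MS_MIN256_F32 pair below. */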
+ MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 4] = MS_MIN256_F32(six, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 8] = MS_MIN256_F32(six, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[30]; + MS_FLOAT32X8 m[25]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + 
MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[30]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[30]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = 
MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 5] = MS_MIN256_F32(six, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 10] = MS_MIN256_F32(six, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 15] = MS_MIN256_F32(six, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + m[l + 20] = MS_MIN256_F32(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[4]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + 
MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = 
MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + m[l + 2] = MS_MIN256_F32(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[9]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 6] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l 
* 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 6] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + 
offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 6] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 3] = MS_MIN256_F32(six, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[32]; + MS_FLOAT32X8 m[16]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 8] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; 
i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[32]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 8] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[32]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + 
MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 8] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 4] = MS_MIN256_F32(six, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 8] = MS_MIN256_F32(six, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[40]; + MS_FLOAT32X8 m[25]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), 
MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 10] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[40]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + 
offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 10] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[40]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 10] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), 
MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 5] = MS_MIN256_F32(six, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 10] = MS_MIN256_F32(six, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 15] = MS_MIN256_F32(six, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + m[l + 20] = MS_MIN256_F32(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), 
bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( 
+ MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 18] = MS_MAX256_F32(zero, m[l + 18]); + m[l + 24] = MS_MAX256_F32(zero, m[l + 24]); + m[l + 30] = MS_MAX256_F32(zero, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + 
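// Column pass of the Winograd output transform (A^T * M * A): even output rows combine the pair sums + // (tmp1..tmp3) and odd rows the pair differences (tmp4..tmp6), weighted by 0.5^k, 1 and 1.5^k for row k; + // the last row also takes the eighth sample t[7 + offset]. Bias is then added and each result is + // clamped to [0, 6] for the fused ReLU6. +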
m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + m[l + 18] = MS_MAX256_F32(zero, m[l + 18]); + m[l + 18] = MS_MIN256_F32(six, m[l + 18]); + m[l + 24] = MS_MAX256_F32(zero, m[l + 24]); + m[l + 24] = MS_MIN256_F32(six, m[l + 24]); + m[l + 30] = MS_MAX256_F32(zero, m[l + 30]); + m[l + 30] = MS_MIN256_F32(six, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = 
MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = 
MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 7] = MS_MAX256_F32(zero, m[l + 7]); + m[l + 14] = MS_MAX256_F32(zero, m[l + 14]); + m[l + 21] = MS_MAX256_F32(zero, m[l + 21]); + m[l + 28] = MS_MAX256_F32(zero, m[l + 28]); + m[l + 35] = MS_MAX256_F32(zero, m[l + 35]); + m[l + 42] = MS_MAX256_F32(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + 
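// Transforms an 8x8 tile of the gemm output into a 7x7 output tile with fused ReLU6: same row/column + // passes as OutputTransform8x7AvxUnit, with each result clamped to [0, 6]. +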
MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 7] = MS_MAX256_F32(zero, m[l + 7]); + m[l + 7] = MS_MIN256_F32(six, m[l + 7]); + m[l + 14] = MS_MAX256_F32(zero, m[l + 14]); + m[l + 14] = MS_MIN256_F32(six, m[l + 14]); + m[l + 21] = MS_MAX256_F32(zero, m[l + 21]); + m[l + 21] = MS_MIN256_F32(six, m[l + 21]); + m[l + 28] = MS_MAX256_F32(zero, m[l + 28]); + m[l + 28] = MS_MIN256_F32(six, m[l + 28]); + m[l + 35] = MS_MAX256_F32(zero, m[l + 35]); + m[l + 35] = MS_MIN256_F32(six, m[l + 35]); + m[l + 42] = MS_MAX256_F32(zero, m[l + 42]); + m[l + 42] = MS_MIN256_F32(six, m[l + 42]); + } + if (r_c == C8NUM && 
r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..a9843129689e06f6eae511d2831703fe092de981 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_AVX +#ifndef MINDSPORE_NNACL_WINOGRAD_AVX_H_ +#define MINDSPORE_NNACL_WINOGRAD_AVX_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +#define LoadAvx16Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); + +#define LoadAvx36Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); \ + src[16] = MS_LD256_F32(src_data + 16 * src_step); \ + src[17] = MS_LD256_F32(src_data + 17 * src_step); \ + src[18] = MS_LD256_F32(src_data + 18 * src_step); \ + src[19] = MS_LD256_F32(src_data + 19 * src_step); \ + src[20] = MS_LD256_F32(src_data + 20 * src_step); \ + src[21] = MS_LD256_F32(src_data + 21 * src_step); \ + src[22] = MS_LD256_F32(src_data + 22 * src_step); \ + src[23] = MS_LD256_F32(src_data + 23 * src_step); \ + src[24] = MS_LD256_F32(src_data + 24 * src_step); \ + src[25] = MS_LD256_F32(src_data + 25 * src_step); \ + src[26] = MS_LD256_F32(src_data + 26 * src_step); \ + src[27] = MS_LD256_F32(src_data + 27 * src_step); \ + src[28] = MS_LD256_F32(src_data + 28 * src_step); \ + src[29] = MS_LD256_F32(src_data + 29 * src_step); \ + src[30] = MS_LD256_F32(src_data + 30 * src_step); \ + src[31] = MS_LD256_F32(src_data + 31 * src_step); \ + src[32] = MS_LD256_F32(src_data + 32 * src_step); \ + src[33] = MS_LD256_F32(src_data + 33 * src_step); \ + src[34] = MS_LD256_F32(src_data + 34 * src_step); \ + src[35] = MS_LD256_F32(src_data + 35 * src_step); + +#define LoadAvx64Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = 
MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); \ + src[16] = MS_LD256_F32(src_data + 16 * src_step); \ + src[17] = MS_LD256_F32(src_data + 17 * src_step); \ + src[18] = MS_LD256_F32(src_data + 18 * src_step); \ + src[19] = MS_LD256_F32(src_data + 19 * src_step); \ + src[20] = MS_LD256_F32(src_data + 20 * src_step); \ + src[21] = MS_LD256_F32(src_data + 21 * src_step); \ + src[22] = MS_LD256_F32(src_data + 22 * src_step); \ + src[23] = MS_LD256_F32(src_data + 23 * src_step); \ + src[24] = MS_LD256_F32(src_data + 24 * src_step); \ + src[25] = MS_LD256_F32(src_data + 25 * src_step); \ + src[26] = MS_LD256_F32(src_data + 26 * src_step); \ + src[27] = MS_LD256_F32(src_data + 27 * src_step); \ + src[28] = MS_LD256_F32(src_data + 28 * src_step); \ + src[29] = MS_LD256_F32(src_data + 29 * src_step); \ + src[30] = MS_LD256_F32(src_data + 30 * src_step); \ + src[31] = MS_LD256_F32(src_data + 31 * src_step); \ + src[32] = MS_LD256_F32(src_data + 32 * src_step); \ + src[33] = MS_LD256_F32(src_data + 33 * src_step); \ + src[34] = MS_LD256_F32(src_data + 34 * src_step); \ + src[35] = MS_LD256_F32(src_data + 35 * src_step); \ + src[36] = MS_LD256_F32(src_data + 36 * src_step); \ + src[37] = MS_LD256_F32(src_data + 37 * src_step); \ + src[38] = MS_LD256_F32(src_data + 38 * src_step); \ + src[39] = MS_LD256_F32(src_data + 39 * src_step); \ + src[40] = MS_LD256_F32(src_data + 40 * src_step); \ + src[41] = MS_LD256_F32(src_data + 41 * src_step); \ + src[42] = MS_LD256_F32(src_data + 42 * src_step); \ + src[43] = MS_LD256_F32(src_data + 43 * src_step); \ + src[44] = MS_LD256_F32(src_data + 44 * src_step); \ + src[45] = MS_LD256_F32(src_data + 45 * src_step); \ + src[46] = MS_LD256_F32(src_data + 46 * src_step); \ + src[47] = MS_LD256_F32(src_data + 47 * src_step); \ + src[48] = MS_LD256_F32(src_data + 48 * src_step); \ + src[49] = MS_LD256_F32(src_data + 49 * src_step); \ + src[50] = MS_LD256_F32(src_data + 50 * src_step); \ + src[51] = MS_LD256_F32(src_data + 51 * src_step); \ + src[52] = MS_LD256_F32(src_data + 52 * src_step); \ + src[53] = MS_LD256_F32(src_data + 53 * src_step); \ + src[54] = MS_LD256_F32(src_data + 54 * src_step); \ + src[55] = MS_LD256_F32(src_data + 55 * src_step); \ + src[56] = MS_LD256_F32(src_data + 56 * src_step); \ + src[57] = MS_LD256_F32(src_data + 57 * src_step); \ + src[58] = MS_LD256_F32(src_data + 58 * src_step); \ + src[59] = MS_LD256_F32(src_data + 59 * src_step); \ + src[60] = MS_LD256_F32(src_data + 60 * src_step); \ + src[61] = MS_LD256_F32(src_data + 61 * src_step); \ + src[62] = MS_LD256_F32(src_data + 62 * src_step); \ + src[63] = MS_LD256_F32(src_data + 63 * src_step); + +#define StoreAvx4Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[2]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[3]); + +#define StoreAvx9Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[3]); \ + MS_ST256_F32(dst_data + 
dst_step * out_c + out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[6]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define StoreAvx16Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + 3 * out_c, m[3]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[5]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[8]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c, m[12]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define StoreAvx25Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + 3 * out_c, m[3]); \ + MS_ST256_F32(dst_data + 4 * out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[5]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[6]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[10]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c, m[15]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c, m[20]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6AvxUnit(const float *src_data, float 
*dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6AvxUnit(const float *src_data, float *dst_data, const float 
*bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_AVX_H_ +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c new file mode 100644 index 0000000000000000000000000000000000000000..49960ef45fdd7c2ff410a02d52734dc8a05e2d99 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c @@ -0,0 +1,281 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void PrepareTransInput(const float *src_data, float *dst_data, int interval_x_s, int interval_x_e, int interval_y_s, + int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * channel_tile * (int)(sizeof(float))); + } + + // get real input block with padding + if (real_c == channel_tile) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_AVX + MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr)); +#else + for (int k = 0; k < channel_tile; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } // interval x loop + } // interval y loop + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } // interval x loop + } // interval y loop + } +} + +// fp32 conv winograd +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? 
input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * in_channel; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int tile_num = C12NUM; + int dst_ic4_offset = dst_plane_offset + ic * channel_tile; + int dst_step = tile_num * in_channel; + float *trans_input_ptr = trans_input + dst_ic4_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only supported on arm64 +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int channel_tile = C4NUM; + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * channel_tile; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ?
channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int block_tile = C12NUM; + int dst_ic8_offset = dst_plane_offset + ic * block_tile * input_unit * input_unit * channel_tile; + size_t dst_step = (size_t)(input_unit * block_tile * channel_tile); + float *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, block_tile * channel_tile); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_channel = conv_param->output_channel_; +#ifndef ENABLE_AVX + int oc4 = UP_DIV(output_channel, C4NUM); +#endif + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); +#ifndef ENABLE_AVX + // non-avx: the gemm output is c8-blocked; consume it in c4 sub-blocks + for (int j = 0; j < oc4; j++) { + int c8_block = j / 2; + int c8_res = j % 2; + int r_c = output_channel - j * C4NUM; + r_c = r_c > C4NUM ? C4NUM : r_c; + int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } +#else + // avx: the gemm output is in nc8hw8; consume whole c8 blocks + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ?
C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; + const float *src_ptr = gemm_out + src_oc8_offset; + const float *bias_ptr = bias_data + j * C8NUM; + float *dst_ptr = out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } +#endif + out_tile_index++; + } +} + +void WinogradOutputNC4HW4Transform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_plane = output_w * output_h; + int output_channel = conv_param->output_channel_; +#ifndef ENABLE_AVX + int oc4 = UP_DIV(output_channel, C4NUM); +#endif + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = dst_x_s + dst_y_s * output_w; +#ifdef ENABLE_AVX + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = (dst_tile_offset + output_plane * j) * C8NUM; + const float *src_ptr = gemm_out + src_oc8_offset; + const float *bias_ptr = bias_data + j * C8NUM; + float *dst_ptr = out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } +#else + for (int j = 0; j < oc4; j++) { + int c8_block = j / 2; + int c8_res = j % 2; + int r_c = output_channel - j * C4NUM; + r_c = r_c > C4NUM ? C4NUM : r_c; + int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM; + int dst_oc4_offset = (dst_tile_offset + output_plane * j) * C4NUM; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } +#endif + out_tile_index++; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..38c19c19d991ba4a432936511d6046dfda6a6d29 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ +#define MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ + +#ifdef ENABLE_ARM +#include <arm_neon.h> +#endif +#include <string.h> +#include "nnacl_c/pack.h" +#include "nnacl_c/fp32/winograd_utils.h" + +#ifdef __cplusplus extern "C" { +#endif +// for fp32 winograd input/output transform +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFunc func); + +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func); + +void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func); + +void WinogradOutputNC4HW4Transform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..345cf64662c4c5269a71b768e299a93a938d477d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c @@ -0,0 +1,4289 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_avx.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef ENABLE_ARM64 +void transpose4(MS_FLOAT32X4 *s0, MS_FLOAT32X4 *s1, MS_FLOAT32X4 *s2, MS_FLOAT32X4 *s3) { + float64x2_t m0 = (float64x2_t)(vtrn1q_f32(*s0, *s1)); + float64x2_t m1 = (float64x2_t)(vtrn2q_f32(*s0, *s1)); + float64x2_t m2 = (float64x2_t)(vtrn1q_f32(*s2, *s3)); + float64x2_t m3 = (float64x2_t)(vtrn2q_f32(*s2, *s3)); + *s0 = (float32x4_t)(vtrn1q_f64(m0, m2)); + *s2 = (float32x4_t)(vtrn2q_f64(m0, m2)); + *s1 = (float32x4_t)(vtrn1q_f64(m1, m3)); + *s3 = (float32x4_t)(vtrn2q_f64(m1, m3)); +} +#endif + +#ifdef ENABLE_AVX +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit}; + +static OutputTransFunc OutputTransFuncList[] = { + OutputTransform4x2AvxUnit, OutputTransform4x3AvxUnit, OutputTransform4x2ReluAvxUnit, + OutputTransform4x3ReluAvxUnit, OutputTransform4x2Relu6AvxUnit, OutputTransform4x3Relu6AvxUnit, + OutputTransform6x2AvxUnit, OutputTransform6x3AvxUnit, OutputTransform6x4AvxUnit, + OutputTransform6x5AvxUnit, OutputTransform6x2ReluAvxUnit, OutputTransform6x3ReluAvxUnit, + OutputTransform6x4ReluAvxUnit, OutputTransform6x5ReluAvxUnit, OutputTransform6x2Relu6AvxUnit, + OutputTransform6x3Relu6AvxUnit, OutputTransform6x4Relu6AvxUnit, OutputTransform6x5Relu6AvxUnit, + OutputTransform8x2AvxUnit, OutputTransform8x3AvxUnit, OutputTransform8x4AvxUnit, + OutputTransform8x5AvxUnit, OutputTransform8x6AvxUnit, OutputTransform8x7AvxUnit, + OutputTransform8x2ReluAvxUnit, OutputTransform8x3ReluAvxUnit, OutputTransform8x4ReluAvxUnit, + OutputTransform8x5ReluAvxUnit, OutputTransform8x6ReluAvxUnit, OutputTransform8x7ReluAvxUnit, + OutputTransform8x2Relu6AvxUnit, OutputTransform8x3Relu6AvxUnit, OutputTransform8x4Relu6AvxUnit, + OutputTransform8x5Relu6AvxUnit, OutputTransform8x6Relu6AvxUnit, OutputTransform8x7Relu6AvxUnit}; +#else +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit}; + +static OutputTransFunc OutputTransFuncList[] = { + OutputTransform4x2Unit, OutputTransform4x3Unit, OutputTransform4x2ReluUnit, OutputTransform4x3ReluUnit, + OutputTransform4x2Relu6Unit, OutputTransform4x3Relu6Unit, OutputTransform6x2Unit, OutputTransform6x3Unit, + OutputTransform6x4Unit, OutputTransform6x5Unit, OutputTransform6x2ReluUnit, OutputTransform6x3ReluUnit, + OutputTransform6x4ReluUnit, OutputTransform6x5ReluUnit, OutputTransform6x2Relu6Unit, OutputTransform6x3Relu6Unit, + OutputTransform6x4Relu6Unit, OutputTransform6x5Relu6Unit, OutputTransform8x2Unit, OutputTransform8x3Unit, + OutputTransform8x4Unit, OutputTransform8x5Unit, OutputTransform8x6Unit, OutputTransform8x7Unit, + OutputTransform8x2ReluUnit, OutputTransform8x3ReluUnit, OutputTransform8x4ReluUnit, OutputTransform8x5ReluUnit, + OutputTransform8x6ReluUnit, OutputTransform8x7ReluUnit, OutputTransform8x2Relu6Unit, OutputTransform8x3Relu6Unit, + OutputTransform8x4Relu6Unit, OutputTransform8x5Relu6Unit, OutputTransform8x6Relu6Unit, OutputTransform8x7Relu6Unit}; +#endif + +InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; } + +#ifdef ENABLE_ARM64 +static InputTransStepFunc InputTransStepFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Step, NULL, InputTransform6x6Step, NULL, InputTransform8x8Step};
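+ +// indexed by input_unit (entries 4, 6 and 8 are populated): a Step function applies one 1-D +// pass of the input transform, and the matching Pack12 function repacks to 12-wide tiles and +// applies the remaining 1-D pass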
+ +static InputTransPackFunc InputTransPackFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Pack12, NULL, InputTransform6x6Pack12, NULL, InputTransform8x8Pack12}; + +InputTransStepFunc GetInputTransStepFunc(int input_unit) { return InputTransStepFuncList[input_unit]; } + +InputTransPackFunc GetInputTransPackFunc(int input_unit) { return InputTransPackFuncList[input_unit]; } +#endif + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[16]; + + src[0] = MS_LDQ_F32(src_data); + src[1] = MS_LDQ_F32(src_data + src_step); + src[2] = MS_LDQ_F32(src_data + 2 * src_step); + src[3] = MS_LDQ_F32(src_data + 3 * src_step); + + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + t[l] = MS_SUBQ_F32(src[offset], src[2 + offset]); + src[offset + 4] = MS_LDQ_F32(src_data + (offset + 4) * src_step); + t[4 + l] = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + src[offset + 5] = MS_LDQ_F32(src_data + (offset + 5) * src_step); + t[8 + l] = MS_SUBQ_F32(src[2 + offset], src[1 + offset]); + src[offset + 6] = MS_LDQ_F32(src_data + (offset + 6) * src_step); + t[12 + l] = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + src[offset + 7] = MS_LDQ_F32(src_data + (offset + 7) * src_step); + } + + int offset = 3 * 4; + t[3] = MS_SUBQ_F32(src[offset], src[2 + offset]); + t[7] = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[11] = MS_SUBQ_F32(src[2 + offset], src[1 + offset]); + t[15] = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + + src[0] = MS_SUBQ_F32(t[0], t[2]); + src[1] = MS_ADDQ_F32(t[1], t[2]); + src[2] = MS_SUBQ_F32(t[2], t[1]); + src[3] = MS_SUBQ_F32(t[3], t[1]); + + for (int l = 1; l < 4; ++l) { + offset = l * 4; + src[offset] = MS_SUBQ_F32(t[offset], t[2 + offset]); + MS_STQ_F32(dst_data + (l - 1) * dst_step, src[offset - 4]); + src[offset + 1] = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_STQ_F32(dst_data + (3 + l) * dst_step, src[offset - 3]); + src[offset + 2] = MS_SUBQ_F32(t[2 + offset], t[1 + offset]); + MS_STQ_F32(dst_data + (7 + l) * dst_step, src[offset - 2]); + src[offset + 3] = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); + MS_STQ_F32(dst_data + (11 + l) * dst_step, src[offset - 1]); + } + + MS_STQ_F32(dst_data + 3 * dst_step, src[12]); + MS_STQ_F32(dst_data + dst_step * 7, src[13]); + MS_STQ_F32(dst_data + dst_step * 11, src[14]); + MS_STQ_F32(dst_data + dst_step * 15, src[15]); + + } else { +#endif + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + const 
float *src_ptr = src_data + l * 4 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s0, s2); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(s1, s2); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s2, s1); + MS_FLOAT32X4 m3 = MS_SUBQ_F32(s3, s1); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + } +#else + float src[4]; + float m[4]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 4; ++l) { + for (int w = 0; w < 4; ++w) { + int tmp_index = l * 4 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 4; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s00, s20); + MS_FLOAT32X4 m1 = MS_SUBQ_F32(s01, s21); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s02, s22); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(s10, s20); + m1 = MS_ADDQ_F32(s11, s21); + m2 = MS_ADDQ_F32(s12, s22); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s20, s10); + m1 = MS_SUBQ_F32(s21, s11); + m2 = MS_SUBQ_F32(s22, s12); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s30, s10); + m1 = MS_SUBQ_F32(s31, s11); + m2 = MS_SUBQ_F32(s32, s12); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 4; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[4]; + float m[4]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 4; ++w) { + src[w] = 
src_data[tmp_index + w * src_point_stride]; + } + + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + for (int w = 0; w < 4; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[36]; + MS_FLOAT32X4 m[36]; + Load36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(src[4 + offset], src[2 + offset]); + t[l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 4), MS_MULQ_N_F32(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADDQ_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUBQ_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + t[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[1 + offset], 4), MS_MULQ_N_F32(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(t[4 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 4), MS_MULQ_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADDQ_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUBQ_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + m[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + m[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[1 + offset], 4), MS_MULQ_N_F32(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } + } else { +#endif + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = 
m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + const float *src_ptr = src_data + l * 6 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(s3, s1); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(s4, s2); + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 4), MS_MULQ_N_F32(s2, 5)), s4); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s1, s2), -4), MS_ADDQ_F32(s3, s4)); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s1, s2), 4), MS_SUBQ_F32(s4, s3)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s1, 4), MS_MULQ_N_F32(s3, 5)), s5); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + } +#else + float src[6]; + float m[6]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 6; ++l) { + for (int w = 0; w < 6; ++w) { + int tmp_index = l * 6 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 6; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 4), MS_MULQ_N_F32(s20, 5)), s40); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 4), MS_MULQ_N_F32(s21, 5)), s41); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 4), MS_MULQ_N_F32(s22, 5)), s42); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s10, s20), -4), MS_ADDQ_F32(s30, s40)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s11, s21), -4), MS_ADDQ_F32(s31, s41)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s12, s22), -4), MS_ADDQ_F32(s32, s42)); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s10, s20), 4), MS_SUBQ_F32(s40, s30)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s11, s21), 4), MS_SUBQ_F32(s41, s31)); + m2 = 
MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s12, s22), 4), MS_SUBQ_F32(s42, s32)); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), 2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), 2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), 2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), -2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), -2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), -2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s10, 4), MS_MULQ_N_F32(s30, 5)), s50); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s11, 4), MS_MULQ_N_F32(s31, 5)), s51); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s12, 4), MS_MULQ_N_F32(s32, 5)), s52); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 6; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[6]; + float m[6]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 6; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + for (int w = 0; w < 6; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[64]; + MS_FLOAT32X4 m[64]; + Load64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = + MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 0.5625), MS_MULQ_N_F32(src[2 + offset], 3.0625)), + MS_MULQ_N_F32(src[4 + offset], 3.5)), + src[6 + offset]);
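+ // 0.5625 = 9/16, 3.0625 = 49/16, 3.5 = 7/2: the 8x8 (F(6x6, 3x3)) transform constants + // are exact binary fractions, shared by the SIMD and scalar paths.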
+ MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 1.125), MS_MULQ_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 2.25), MS_MULQ_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.5625), MS_MULQ_N_F32(src[4 + offset], 2.5)); + t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.375), MS_MULQ_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.25), MS_MULQ_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], -0.5625), MS_MULQ_N_F32(src[3 + offset], 3.0625)), + MS_MULQ_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 0.5625), MS_MULQ_N_F32(t[2 + offset], 3.0625)), + MS_MULQ_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 1.125), MS_MULQ_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 2.25), MS_MULQ_N_F32(t[4 + offset], 3.25)); + m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.5625), MS_MULQ_N_F32(t[4 + offset], 2.5)); + m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.375), MS_MULQ_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.25), MS_MULQ_N_F32(t[4 + offset], 1.25)); + m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], -0.5625), MS_MULQ_N_F32(t[3 + offset], 3.0625)), + MS_MULQ_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } +} +#endif + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + InputTransform8x8Unit_block4(src_data, dst_data, src_step, dst_step);
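+ // channel tails (real_c < 4) fall through to the scalar path below.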
+ } else { +#endif + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + const float *src_ptr = src_data + l * 8 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 6 * src_step); + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 7 * src_step); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 0.5625), MS_MULQ_N_F32(s2, 3.0625)), MS_MULQ_N_F32(s4, 3.5)), s6); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 1.125), MS_MULQ_N_F32(s5, 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 2.25), MS_MULQ_N_F32(s4, 3.25)); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 
1.625)), s6); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.625)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.5625), s5); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.5625), MS_MULQ_N_F32(s4, 2.5)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 2.5)), s6); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 2.5)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.375), MS_MULQ_N_F32(s5, 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.25), MS_MULQ_N_F32(s4, 1.25)); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m6 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m7 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s1, -0.5625), MS_MULQ_N_F32(s3, 3.0625)), MS_MULQ_N_F32(s5, 3.5)), s7); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + MS_STQ_F32(dst_ptr + 6 * dst_step, m6); + MS_STQ_F32(dst_ptr + 7 * dst_step, m7); + } +#else + float src[8]; + float m[8]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 8; ++l) { + for (int w = 0; w < 8; ++w) { + int tmp_index = l * 8 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 8; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + LOAD_LINE_DATA(6); + LOAD_LINE_DATA(7); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 0.5625), MS_MULQ_N_F32(s20, 3.0625)), MS_MULQ_N_F32(s40, 3.5)), s60); + MS_FLOAT32X4 m1 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 0.5625), MS_MULQ_N_F32(s21, 3.0625)), MS_MULQ_N_F32(s41, 3.5)), s61); + MS_FLOAT32X4 m2 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 0.5625), MS_MULQ_N_F32(s22, 3.0625)), MS_MULQ_N_F32(s42, 3.5)), s62); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + MS_FLOAT32X4 tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 1.125), MS_MULQ_N_F32(s50, 0.5)); + MS_FLOAT32X4 tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 1.125), MS_MULQ_N_F32(s51, 0.5)); + MS_FLOAT32X4 tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 1.125), MS_MULQ_N_F32(s52, 0.5)); + MS_FLOAT32X4 tmp20 = 
MS_SUBQ_F32(MS_MULQ_N_F32(s20, 2.25), MS_MULQ_N_F32(s40, 3.25)); + MS_FLOAT32X4 tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 2.25), MS_MULQ_N_F32(s41, 3.25)); + MS_FLOAT32X4 tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 2.25), MS_MULQ_N_F32(s42, 3.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.5625), s50); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.5625), s51); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.5625), s52); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.5625), MS_MULQ_N_F32(s40, 2.5)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.5625), MS_MULQ_N_F32(s41, 2.5)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.5625), MS_MULQ_N_F32(s42, 2.5)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.375), MS_MULQ_N_F32(s50, 1.5)); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.375), MS_MULQ_N_F32(s51, 1.5)); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.375), MS_MULQ_N_F32(s52, 1.5)); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.25), MS_MULQ_N_F32(s40, 1.25)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.25), MS_MULQ_N_F32(s41, 1.25)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.25), MS_MULQ_N_F32(s42, 1.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 6 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s10, -0.5625), MS_MULQ_N_F32(s30, 3.0625)), MS_MULQ_N_F32(s50, 3.5)), s70); + m1 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s11, -0.5625), MS_MULQ_N_F32(s31, 3.0625)), MS_MULQ_N_F32(s51, 3.5)), s71); + m2 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s12, -0.5625), MS_MULQ_N_F32(s32, 3.0625)), MS_MULQ_N_F32(s52, 3.5)), s72); + MS_STQ_F32(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 7 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 8; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[8]; + float m[8]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 8; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + for (int w = 0; w < 8; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +// OutputTransFuncList is grouped by input_unit (4, 6, 8): each group holds (input_unit - 2) +// output sizes x 3 activations (none, relu, relu6), so the group bases are 0, 6 and 18 +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) { + if (!CheckWinogradInputOutputUnit(input_unit, output_unit)) { + return NULL; + } + int in_index = (input_unit - 4) / 2; + int index = 0; + for (int i = 0; i < in_index; i++) { + index += ((i * 2 + 4) - 2) * 3; + } + int act_index; + if (act_type == ActType_Relu) { + act_index = 1; + } else if (act_type == ActType_Relu6) { + act_index = 2; + } else { + act_index = 0; + } + return OutputTransFuncList[index + (input_unit - 2) * act_index + output_unit - 2]; +} + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
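+ // F(2x2, 3x3) output transform: m = A^T * t * A with A^T = {{1, 1, 1, 0}, {0, 1, -1, 1}}; + // the bias is added to each of the 2x2 outputs.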
+ MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + 
offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * 
out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + 
bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + 
bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int 
offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + 
dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 
src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 4] = MS_MINQ_F32(six, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 8] = MS_MINQ_F32(six, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * 
out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + 
offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + 
offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 5] = MS_MINQ_F32(six, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 10] = MS_MINQ_F32(six, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 15] = MS_MINQ_F32(six, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + m[l + 20] = MS_MINQ_F32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - 
src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + 
offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } 
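+ // The loop below is the second stage of the scalar fallback: it applies the
+ // same two-row output transform down the columns of t that the loop above
+ // applied across the rows. Row 0 sums taps 0..6; row 1 weights the pairwise
+ // differences by 0.5f, 1.0f, 1.5f and adds tap 7. These factors (and their
+ // powers 0.25f/2.25f and 0.125f/3.375f in the wider 8-tap transforms) match
+ // the Winograd interpolation points 0, +-0.5, +-1, +-1.5; this reading is
+ // inferred from the coefficients, not stated in the surrounding code.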
+ for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + 
t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * 
(src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + 
offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step 
* out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), 
bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + 
src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 4] = MS_MINQ_F32(six, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 8] = MS_MINQ_F32(six, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + 
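+      // The remaining rows apply A^T of the inverse Winograd transform to this
+      // column of the 8x8 tile: the constants are powers of the interpolation
+      // points 0.5 and 1.5 (0.25 = 0.5^2, 2.25 = 1.5^2, 0.125 = 0.5^3,
+      // 3.375 = 1.5^3), with sums of src pairs producing the even rows and
+      // differences the odd rows; bias and the activation are fused at the store.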
t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + 
offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 5] = MS_MINQ_F32(six, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 10] = MS_MINQ_F32(six, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 15] = MS_MINQ_F32(six, m[l + 15]); + m[l + 20] = 
MS_MAXQ_F32(zero, m[l + 20]); + m[l + 20] = MS_MINQ_F32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + 
m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), 
MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 18] = MS_MAXQ_F32(zero, m[l + 18]); + m[l + 24] = MS_MAXQ_F32(zero, m[l + 24]); + m[l + 30] = MS_MAXQ_F32(zero, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + 
offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = 
MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + m[l + 18] = MS_MAXQ_F32(zero, m[l + 18]); + m[l + 18] = MS_MINQ_F32(six, m[l + 18]); + m[l + 24] = MS_MAXQ_F32(zero, m[l + 24]); + m[l + 24] = MS_MINQ_F32(six, m[l + 24]); + m[l + 30] = MS_MAXQ_F32(zero, m[l + 30]); + m[l + 30] = MS_MINQ_F32(six, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + 
offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 
0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); 
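+      // Rows 5 and 6 of the 8-point transform continue the same pattern with the
+      // 5th and 6th powers of the interpolation points (0.03125 = 0.5^5,
+      // 7.59375 = 1.5^5, 0.015625 = 0.5^6, 11.390625 = 1.5^6); a 7-row output
+      // from an 8x8 tile matches an F(7, 2)-style unit, i.e. tile = 7 + 2 - 1.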
+ m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 7] = MS_MAXQ_F32(zero, 
m[l + 7]); + m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); + m[l + 21] = MS_MAXQ_F32(zero, m[l + 21]); + m[l + 28] = MS_MAXQ_F32(zero, m[l + 28]); + m[l + 35] = MS_MAXQ_F32(zero, m[l + 35]); + m[l + 42] = MS_MAXQ_F32(zero, m[l + 42]); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * 
(t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 7] = MS_MAXQ_F32(zero, m[l + 7]); + m[l + 7] = MS_MINQ_F32(six, m[l + 7]); + m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); + m[l 
+ 14] = MS_MINQ_F32(six, m[l + 14]); + m[l + 21] = MS_MAXQ_F32(zero, m[l + 21]); + m[l + 21] = MS_MINQ_F32(six, m[l + 21]); + m[l + 28] = MS_MAXQ_F32(zero, m[l + 28]); + m[l + 28] = MS_MINQ_F32(six, m[l + 28]); + m[l + 35] = MS_MAXQ_F32(zero, m[l + 35]); + m[l + 35] = MS_MINQ_F32(six, m[l + 35]); + m[l + 42] = MS_MAXQ_F32(zero, m[l + 42]); + m[l + 42] = MS_MINQ_F32(six, m[l + 42]); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + 
+                7.59375f * (t[5 + offset] - t[6 + offset]);
+      m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) +
+                  11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 7;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        out_value = out_value < 6 ? out_value : 6;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..00d4705b3bad72c17be3da1bf4f438acdf24b6f7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h
@@ -0,0 +1,373 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_WINOGRAD_UTILS_H_
+#define MINDSPORE_NNACL_WINOGRAD_UTILS_H_
+
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
+
+typedef void (*InputTransStepFunc)(const float *src_data, float *dst_data, int src_step, int dst_step,
+                                   int dst_row_step);
+
+typedef void (*InputTransPackFunc)(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
+
+typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+
+typedef struct TransFuncList {
+  InputTransFunc in_func_;
+  InputTransStepFunc in_step_func_;
+  InputTransPackFunc in_pack_func_;
+  OutputTransFunc out_func_;
+} TransFuncList;
+
+#define Load16Data                                 \
+  src[0] = MS_LDQ_F32(src_data + 0 * src_step);    \
+  src[1] = MS_LDQ_F32(src_data + 1 * src_step);    \
+  src[2] = MS_LDQ_F32(src_data + 2 * src_step);    \
+  src[3] = MS_LDQ_F32(src_data + 3 * src_step);    \
+  src[4] = MS_LDQ_F32(src_data + 4 * src_step);    \
+  src[5] = MS_LDQ_F32(src_data + 5 * src_step);    \
+  src[6] = MS_LDQ_F32(src_data + 6 * src_step);    \
+  src[7] = MS_LDQ_F32(src_data + 7 * src_step);    \
+  src[8] = MS_LDQ_F32(src_data + 8 * src_step);    \
+  src[9] = MS_LDQ_F32(src_data + 9 * src_step);    \
+  src[10] = MS_LDQ_F32(src_data + 10 * src_step);  \
+  src[11] = MS_LDQ_F32(src_data + 11 * src_step);  \
+  src[12] = MS_LDQ_F32(src_data + 12 * src_step);  \
+  src[13] = MS_LDQ_F32(src_data + 13 * src_step);  \
+  src[14] = MS_LDQ_F32(src_data + 14 * src_step);  \
+  src[15] = MS_LDQ_F32(src_data + 15 * src_step);
+
+#define Load36Data                                 \
+  src[0] = MS_LDQ_F32(src_data + 0 * src_step);    \
+  src[1] = MS_LDQ_F32(src_data + 1 * src_step);    \
+  src[2] = MS_LDQ_F32(src_data + 2 * src_step);    \
+  src[3] = MS_LDQ_F32(src_data + 3
* src_step); \ + src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ + src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ + src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ + src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ + src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ + src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ + src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ + src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ + src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ + src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ + src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ + src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ + src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ + src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ + src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ + src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ + src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ + src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ + src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ + src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ + src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ + src[35] = MS_LDQ_F32(src_data + 35 * src_step); + +#define Load64Data \ + src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ + src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ + src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ + src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ + src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ + src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ + src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ + src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ + src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ + src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ + src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ + src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ + src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ + src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ + src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ + src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ + src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ + src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ + src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ + src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ + src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ + src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ + src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ + src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ + src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ + src[35] = MS_LDQ_F32(src_data + 35 * src_step); \ + src[36] = 
MS_LDQ_F32(src_data + 36 * src_step); \ + src[37] = MS_LDQ_F32(src_data + 37 * src_step); \ + src[38] = MS_LDQ_F32(src_data + 38 * src_step); \ + src[39] = MS_LDQ_F32(src_data + 39 * src_step); \ + src[40] = MS_LDQ_F32(src_data + 40 * src_step); \ + src[41] = MS_LDQ_F32(src_data + 41 * src_step); \ + src[42] = MS_LDQ_F32(src_data + 42 * src_step); \ + src[43] = MS_LDQ_F32(src_data + 43 * src_step); \ + src[44] = MS_LDQ_F32(src_data + 44 * src_step); \ + src[45] = MS_LDQ_F32(src_data + 45 * src_step); \ + src[46] = MS_LDQ_F32(src_data + 46 * src_step); \ + src[47] = MS_LDQ_F32(src_data + 47 * src_step); \ + src[48] = MS_LDQ_F32(src_data + 48 * src_step); \ + src[49] = MS_LDQ_F32(src_data + 49 * src_step); \ + src[50] = MS_LDQ_F32(src_data + 50 * src_step); \ + src[51] = MS_LDQ_F32(src_data + 51 * src_step); \ + src[52] = MS_LDQ_F32(src_data + 52 * src_step); \ + src[53] = MS_LDQ_F32(src_data + 53 * src_step); \ + src[54] = MS_LDQ_F32(src_data + 54 * src_step); \ + src[55] = MS_LDQ_F32(src_data + 55 * src_step); \ + src[56] = MS_LDQ_F32(src_data + 56 * src_step); \ + src[57] = MS_LDQ_F32(src_data + 57 * src_step); \ + src[58] = MS_LDQ_F32(src_data + 58 * src_step); \ + src[59] = MS_LDQ_F32(src_data + 59 * src_step); \ + src[60] = MS_LDQ_F32(src_data + 60 * src_step); \ + src[61] = MS_LDQ_F32(src_data + 61 * src_step); \ + src[62] = MS_LDQ_F32(src_data + 62 * src_step); \ + src[63] = MS_LDQ_F32(src_data + 63 * src_step); + +#define LOAD_LINE_DATA(line) \ + MS_FLOAT32X4 s##line##0 = MS_LDQ_F32(src_ptr + line * src_point_stride + 0 * pack_tile); \ + MS_FLOAT32X4 s##line##1 = MS_LDQ_F32(src_ptr + line * src_point_stride + 1 * pack_tile); \ + MS_FLOAT32X4 s##line##2 = MS_LDQ_F32(src_ptr + line * src_point_stride + 2 * pack_tile); + +#define TRANSPOSE_12x4 \ + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * pack_tile); \ + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 1 * pack_tile); \ + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 2 * pack_tile); \ + MS_FLOAT32X4 s9 = MS_LDQ_F32(src_ptr + 3 * pack_tile); \ + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 4 * pack_tile); \ + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 5 * pack_tile); \ + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 6 * pack_tile); \ + MS_FLOAT32X4 s10 = MS_LDQ_F32(src_ptr + 7 * pack_tile); \ + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 8 * pack_tile); \ + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 9 * pack_tile); \ + MS_FLOAT32X4 s8 = MS_LDQ_F32(src_ptr + 10 * pack_tile); \ + MS_FLOAT32X4 s11 = MS_LDQ_F32(src_ptr + 11 * pack_tile); \ + transpose4(&s0, &s3, &s6, &s9); \ + transpose4(&s1, &s4, &s7, &s10); \ + transpose4(&s2, &s5, &s8, &s11); \ + MS_STQ_F32(src_ptr + 0 * pack_tile, s0); \ + MS_STQ_F32(src_ptr + 1 * pack_tile, s1); \ + MS_STQ_F32(src_ptr + 2 * pack_tile, s2); \ + MS_STQ_F32(src_ptr + 3 * pack_tile, s3); \ + MS_STQ_F32(src_ptr + 4 * pack_tile, s4); \ + MS_STQ_F32(src_ptr + 5 * pack_tile, s5); \ + MS_STQ_F32(src_ptr + 6 * pack_tile, s6); \ + MS_STQ_F32(src_ptr + 7 * pack_tile, s7); \ + MS_STQ_F32(src_ptr + 8 * pack_tile, s8); \ + MS_STQ_F32(src_ptr + 9 * pack_tile, s9); \ + MS_STQ_F32(src_ptr + 10 * pack_tile, s10); \ + MS_STQ_F32(src_ptr + 11 * pack_tile, s11); + +InputTransFunc GetInputTransFunc(int input_unit); + +#ifdef ENABLE_ARM64 +InputTransStepFunc GetInputTransStepFunc(int input_unit); + +InputTransPackFunc GetInputTransPackFunc(int input_unit); +#endif + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int 
dst_step, int dst_row_step); + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); + +#define Store4Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[2]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[3]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[6]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[5]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[8]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[12]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ + MS_STQ_F32(dst_data + 4 * out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[5]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[6]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[10]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[15]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, 
m[17]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c, m[20]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int 
dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +int SelectOutputUnit(const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c84f5317d49765e19c23a3dfb5049bca1a83641d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/activation_grad_simd.h"
+
+int ReluGrad(const float *src0, const float *src1, int length, float *dst) {
+  int i = 0;
+#ifdef ENABLE_ARM
+  float32x4_t zero_4 = vdupq_n_f32(0.0f);
+  // process whole 4-float blocks; the scalar loop below finishes the tail
+  for (; i + C4NUM <= length; i += C4NUM) {
+    float32x4_t src1_4 = vld1q_f32(src1 + i);
+    float32x4_t src0_4 = vld1q_f32(src0 + i);
+    uint32x4_t mask_4 = vcleq_f32(src1_4, zero_4);
+    float32x4_t dst_4 = vbslq_f32(mask_4, zero_4, src0_4);
+    vst1q_f32(dst + i, dst_4);
+  }
+#endif
+  for (; i < length; ++i) {
+    dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
+  }
+  return NNACL_OK;
+}
+
+int Relu6Grad(const float *src0, const float *src1, size_t length, float *dst) {
+  size_t i = 0;
+#ifdef ENABLE_ARM
+  float32x4_t zero_4 = vdupq_n_f32(0.0f);
+  float32x4_t six_4 = vdupq_n_f32(6.0f);
+  // `i + C4NUM <= length` rather than `i < length - C4NUM`: the subtraction
+  // underflows for size_t lengths smaller than C4NUM
+  for (; i + C4NUM <= length; i += C4NUM) {
+    float32x4_t src1_4 = vld1q_f32(src1 + i);
+    float32x4_t src0_4 = vld1q_f32(src0 + i);
+    uint32x4_t gt_4 = vcgtq_f32(src1_4, zero_4);
+    uint32x4_t le_4 = vcleq_f32(src1_4, six_4);
+    uint32x4_t mask_4 = vandq_u32(gt_4, le_4);
+    float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4);
+    vst1q_f32(dst + i, dst_4);
+  }
+#endif
+  for (; i < length; ++i) {
+    dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f;
+  }
+  return NNACL_OK;
+}
+
+int LReluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = src1[i] > 0.0f ? src0[i] : alpha * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int SigmoidGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
+  }
+  return NNACL_OK;
+}
+
+int TanhGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f));
+    dst[i] = tmp * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
+    dst[i] = tmp * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = (src1[i] > 0.0f ?
src0[i] : alpha * expm1(src1[i]) * src0[i]); + } + return NNACL_OK; +} + +int GeluGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + + (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); + } + return NNACL_OK; +} + +int SoftplusGrad(const float *src0, const float *src1, int length, float *dst) { + int i = 0; +#if defined(ENABLE_AVX) + for (; i <= length - C8NUM; i += C8NUM) { + simd_exp256(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src1 + i))), dst + i); + MS_ST256_F32(dst + i, + MS_DIV256_F32(MS_LD256_F32(src0 + i), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i)))); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; i <= length - C4NUM; i += C4NUM) { + simd_exp128(MS_SUBQ_F32(MS_MOVQ_F32(0.0f), MS_LDQ_F32(src1 + i)), dst + i); + MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_LDQ_F32(src0 + i), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i)))); + } +#endif + + for (; i < length; ++i) { + simd_exp32(-src1[i], dst + i); + dst[i] = src0[i] / (1.0f + dst[i]); + } + return NNACL_OK; +} + +int HardShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(ShrinkGrad, i, src0, src1, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src1[i] >= neg_lambd && src1[i] <= lambd) ? 0 : src0[i]; + } + return NNACL_OK; +} + +int SoftShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(ShrinkGrad, i, src0, src1, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src1[i] >= neg_lambd && src1[i] <= lambd) ? 0 : src0[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..ab2933c68f5e90fb2b0f3013ba736e7a279549dc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
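The SoftplusGrad kernel above relies on the identity d/dx log(1 + e^x) = 1/(1 + e^-x), i.e. the sigmoid of the forward input: the SIMD paths write exp(-src1) into dst and then divide src0 by (1 + dst). A standalone scalar check of that formula (illustrative snippet, not part of the patch):

#include <math.h>
#include <stdio.h>

int main(void) {
  /* softplus'(x) = sigmoid(x); at x = 0 the upstream gradient should be halved */
  float x = 0.0f, dy = 2.0f;
  float dx = dy / (1.0f + expf(-x));
  printf("%g\n", dx); /* prints 1 */
  return 0;
}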
+ */
+#ifndef NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_
+#define NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/errorcode.h"
+
+typedef struct ActivationGradParameter {
+  OpParameter op_parameter;
+  int type_;
+  float alpha_;
+} ActivationGradParameter;
+#ifdef __cplusplus
extern "C" {
+#endif
+
+int ReluGrad(const float *src0, const float *src1, int length, float *dst);
+int Relu6Grad(const float *src0, const float *src1, size_t length, float *dst);
+int LReluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha);
+int SigmoidGrad(const float *src0, const float *src1, size_t length, float *dst);
+int TanhGrad(const float *src0, const float *src1, size_t length, float *dst);
+int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst);
+int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst);
+int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha);
+int GeluGrad(const float *src0, const float *src1, size_t length, float *dst);
+int SoftplusGrad(const float *src0, const float *src1, int length, float *dst);
+int HardShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd);
+int SoftShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..1f98839b0f24202e681c53509b671099b3e6031e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
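A convention worth noting in these prototypes: src0 is always the incoming gradient, while src1 is whatever the forward pass saved. For SigmoidGrad and TanhGrad that is the forward output y (the derivatives are y(1 - y) and 1 - y^2), whereas ReluGrad, Relu6Grad and HSwishGrad take the forward input. A hedged usage sketch (sigmoid_backward is a hypothetical wrapper; assumes the nnacl_c include paths are configured):

#include "nnacl_c/fp32_grad/activation_grad_fp32.h"

/* dy: upstream gradient; y: sigmoid output saved by the forward pass */
static void sigmoid_backward(const float *dy, const float *y, size_t n, float *dx) {
  (void)SigmoidGrad(dy, y, n, dx); /* dx[i] = dy[i] * y[i] * (1.0f - y[i]) */
}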
+ */ + +#ifndef NNACL_FP32_GRAD_ACTIVATION_GRAD_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_GRAD_ACTIVATION_GRAD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ShrinkGrad@SIMD_INSTRUCTION@(int index, const float *src0, const float *src1, + int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src0_t = SIMD_LD_F32(src0 + index); + SIMD_F32 src1_t = SIMD_LD_F32(src1 + index); + + SIMD_MASK mask0 = SIMD_CMPLE_F32(src1_t, pos_lamdb_v); + SIMD_MASK mask1 = SIMD_CMPLE_F32(neg_lamdb_v, src1_t); + SIMD_MASK mask = SIMD_AND_MASK(mask0, mask1); + + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src0_t, SIMD_MOV_F32(0.0f), mask)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..92d494f070008eed3eec8227ec6167c22fdde7eb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c @@ -0,0 +1,48 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
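The ShrinkGrad template above serves both HardShrinkGrad and SoftShrinkGrad: in either case the gradient is zero exactly where the forward shrink operator has its dead zone, |x| <= lambd, and passes src0 through elsewhere, which is why the two scalar tails in the .c file are identical. One SIMD lane reduces to this scalar form (reference only):

static inline float shrink_grad_lane(float g, float x, float lambd) {
  /* zero inside the dead zone [-lambd, lambd], pass-through outside */
  return (x >= -lambd && x <= lambd) ? 0.0f : g;
}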
+ */ +#include "nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/apply_proximal_adagrad_fp32_simd.h" + +int Sign(float x) { + if (x > 0) { + return 1; + } + if (x < 0) { + return -1; + } + return 0; +} + +void ApplyProximalAdagradOpt(float *var, float *accum, float lr, float l1, float l2, float *grad, + int64_t input_elements) { + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(ApplyProximalAdagradOpt, i, var, accum, lr, l1, l2, grad, input_elements); + + for (; i < input_elements; ++i) { + accum[i] += grad[i] * grad[i]; + float learning_rate = lr / sqrt(accum[i]); + float prox_v = var[i]; + prox_v -= grad[i] * learning_rate; + + if (l1 > 0) { + var[i] = Sign(prox_v) * fmax(fabs(prox_v) - learning_rate * l1, 0.0) / (1 + l2 * learning_rate); + } else { + var[i] = prox_v / (1 + l2 * learning_rate); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..00bc67a17a23148e83a5a0e46d906274ec70f059 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ +#define NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ApplyProximalAdagradOpt(float *var, float *accum, float lr, float l1, float l2, float *grad, + int64_t input_elements); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..70a434291aef339fdf3c84cb95cfcd561ee6fc80 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in @@ -0,0 +1,68 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ApplyProximalAdagradOpt@SIMD_INSTRUCTION@( + int64_t index, float *var, float *accum, float lr, float l1, float l2, float *grad, int64_t size) { + SIMD_F32 lr_vec = SIMD_MOV_F32(lr); + SIMD_F32 l1_vec = SIMD_MOV_F32(l1); + SIMD_F32 l2_vec = SIMD_MOV_F32(l2); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp_vec1 = SIMD_LD_F32(grad + index); + SIMD_F32 accum_vec = SIMD_LD_F32(accum + index); + SIMD_F32 prox_v_vec = SIMD_LD_F32(var + index); + + accum_vec = SIMD_FMADD_F32(tmp_vec1, tmp_vec1, accum_vec); + SIMD_F32 learn_rate_vec = SIMD_DIV_F32(lr_vec, SIMD_SQRT_F32(accum_vec)); + prox_v_vec = SIMD_SUB_F32(prox_v_vec, SIMD_MUL_F32(tmp_vec1, learn_rate_vec)); + SIMD_ST_F32(accum + index, accum_vec); + tmp_vec1 = SIMD_FMADD_F32(l2_vec, learn_rate_vec, SIMD_MOV_F32(1)); + if (l1 > 0) { + learn_rate_vec = SIMD_MUL_F32(learn_rate_vec, l1_vec); + learn_rate_vec = SIMD_SUB_F32(SIMD_ABS_F32(prox_v_vec), learn_rate_vec); + learn_rate_vec = SIMD_MAX_F32(learn_rate_vec, SIMD_MOV_F32(0.0f)); + learn_rate_vec = SIMD_DIV_F32(learn_rate_vec, tmp_vec1); + + SIMD_MASK greater_mask = SIMD_CMPGT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_MASK less_mask = SIMD_CMPLT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_F32 greater_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, greater_mask); + SIMD_F32 less_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, less_mask); + greater_v = SIMD_SUB_F32(greater_v, less_v); + + prox_v_vec = SIMD_MUL_F32(learn_rate_vec, greater_v); + } else { + prox_v_vec = SIMD_DIV_F32(prox_v_vec, tmp_vec1); + } + SIMD_ST_F32(var + index, prox_v_vec); + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..82c6624552768ed338919613aa796146546edda6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
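All of these *_simd.h.in templates share one contract: the generated kernel consumes full BLOCK_NUM-wide blocks starting at index and returns the first index it did not touch, so the caller's scalar loop finishes the tail. The loop bound size - BLOCK_NUM + 1 is just index + BLOCK_NUM <= size rewritten. A minimal model of that contract (schematic only; the real dispatch is generated per instruction set):

#include <stdint.h>
#include <stdio.h>

/* Model: consume whole block_num-wide blocks, return the first index
 * left for the caller's scalar tail. */
static int64_t block_kernel_model(int64_t index, int64_t size, int64_t block_num) {
  for (int64_t block_max_size = size - block_num + 1; index < block_max_size; index += block_num) {
    /* a real kernel processes elements [index, index + block_num) with SIMD */
  }
  return index;
}

int main(void) {
  printf("%lld\n", (long long)block_kernel_model(0, 10, 4)); /* prints 8 */
  return 0;
}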
+ */ +#include "nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/apply_proximal_gradient_descent_fp32_simd.h" + +void ApplyProximalGradientDescentOpt(float *var, float alpha, float l1, float l2, float *delta, + int64_t input_elements) { + int64_t i = 0; + SIMD_RUN_NO_SCALAR(ApplyProximalGradientDescentOpt, i, var, alpha, l1, l2, delta, input_elements); + for (; i < input_elements; ++i) { + float prox_v = var[i]; + prox_v -= delta[i] * alpha; + + if (l1 > 0) { + var[i] = SignFp32(prox_v) * fmax(fabs(prox_v) - alpha * l1, 0.0) / (1 + l2 * alpha); + } else { + var[i] = prox_v / (1 + l2 * alpha); + } + } +} + +float SignFp32(const float x) { + if (x > 0.0) { + return 1.0; + } + if (x < 0.0) { + return -1.0; + } + return 0.0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3e519010006b2bc042e94d5faf09bac37d3bb5f9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ +#define NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ApplyProximalGradientDescentOpt(float *var, float alpha, float l1, float l2, float *delta, int64_t input_elements); +float SignFp32(const float x); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..f885665b93d2ddeaf14b7248b3c0512eba2c1e8b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in @@ -0,0 +1,64 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ApplyProximalGradientDescentOpt@SIMD_INSTRUCTION@( + int64_t index, float *var, float alpha, float l1, float l2, float *delta, int64_t size) { + SIMD_F32 alpha_vec = SIMD_MOV_F32(alpha); + SIMD_F32 l1_vec = SIMD_MOV_F32(l1); + SIMD_F32 l2_vec = SIMD_MOV_F32(l2); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 delta_vec = SIMD_LD_F32(delta + index); + SIMD_F32 prox_v_vec = SIMD_LD_F32(var + index); + + prox_v_vec = SIMD_SUB_F32(prox_v_vec, SIMD_MUL_F32(delta_vec, alpha_vec)); + SIMD_F32 tmp_vec1 = SIMD_FMADD_F32(l2_vec, alpha_vec, SIMD_MOV_F32(1)); + if (l1 > 0) { + SIMD_F32 tmp_vec2 = SIMD_MUL_F32(alpha_vec, l1_vec); + tmp_vec2 = SIMD_SUB_F32(SIMD_ABS_F32(prox_v_vec), tmp_vec2); + tmp_vec2 = SIMD_MAX_F32(tmp_vec2, SIMD_MOV_F32(0.0f)); + tmp_vec2 = SIMD_DIV_F32(tmp_vec2, tmp_vec1); + + SIMD_MASK greater_mask = SIMD_CMPGT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_MASK less_mask = SIMD_CMPLT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_F32 greater_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, greater_mask); + SIMD_F32 less_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, less_mask); + greater_v = SIMD_SUB_F32(greater_v, less_v); + + prox_v_vec = SIMD_MUL_F32(tmp_vec2, greater_v); + } else { + prox_v_vec = SIMD_DIV_F32(prox_v_vec, tmp_vec1); + } + SIMD_ST_F32(var + index, prox_v_vec); + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..0cb6d2015b4476a0566bd4c53e9a085bc4436f5c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
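Both proximal templates replace the scalar sign() branch with compare masks and blends (the greater_v/less_v pair above). The scalar logic they mirror is the Sign/SignFp32 helper in the .c files, which can itself be written branch-free:

static inline float sign_f32(float x) {
  /* returns 1.0f, -1.0f, or 0.0f without branching */
  return (float)((x > 0.0f) - (x < 0.0f));
}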
+ */ + +#include "nnacl_c/fp32_grad/arithmetic_grad.h" +#include +#include +#include "nnacl_c/fp32_grad/utils.h" +#include "nnacl_c/errorcode.h" + +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -nom[i] / (denom[i] * denom[i]); + } +} + +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -a[i] * b[i] / (denom[i] * denom[i]); + } +} + +int ElementAbsGrad(const float *in1, const float *in2, float *out, int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = (in1[i] < 0.f) ? -in2[i] : ((in1[i] > 0.f) ? in2[i] : 0); + } + return NNACL_OK; +} + +void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] > input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] >= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float)); // zero output + memset(output1, 0, num_output1 * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1 && num_axes0 < C5NUM) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1 && num_axes1 < C5NUM) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] < input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] <= input0[offset] ? 
dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float)); // zero output + memset(output1, 0, num_output1 * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1 && num_axes0 < C5NUM) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1 && num_axes1 < C5NUM) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] <= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = 0.5f * in2[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = -0.5f * in2[i] * in1[i] * in1[i] * in1[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..b46b72dd87f2257ba4a1d152016fa1469afcc1c4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
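As elsewhere in this patch, in1 for ElementSqrtGrad and ElementRsqrtGrad is the saved forward output: with y = sqrt(x), 0.5*dy/y equals dy/(2*sqrt(x)); with y = x^(-1/2), the derivative -x^(-3/2)/2 is reconstructed as -0.5*dy*y^3. A one-element numeric check (illustrative values):

#include <stdio.h>

int main(void) {
  /* forward: y = sqrt(4) = 2; with dy = 1, expect 1 / (2 * sqrt(4)) = 0.25 */
  float y = 2.0f, dy = 1.0f;
  printf("%g\n", 0.5f * dy / y);
  return 0;
}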
+ */
+#ifndef NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
+#define NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size);
+void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size);
+int ElementAbsGrad(const float *in1, const float *in2, float *out, int element_size);
+void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims,
+                   const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims);
+void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims,
+                   const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims);
+int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size);
+int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..3787acc5277cc2649981b6781921749de772f2d9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/fp32_grad/batch_norm_grad.h"
+
+void var2Invar(float *save_var, int size, float eps) {
+  for (int i = 0; i < size; i++) {
+    save_var[i] = 1.0f / sqrtf(save_var[i] + eps);
+  }
+}
+
+static void backwardComputeDx(const float *in, const float *yt, const float *mean, const float *invar,
+                              const float *scale, int size, int ch, const float *dbias, const float *dscale, float *dx,
+                              float N, bool is_train) {
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      // dx_2
+      int ix = i * ch + c;
+      dx[ix] = yt[ix];
+      if (is_train) {
+        dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
+      }
+      dx[ix] *= scale[c] * invar[c];
+    }
+  }
+}
+
+#ifdef _MSC_VER
+void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
+                 int ch, float *dbias, float *dscale, float *dx, bool is_train) {
+#else
+void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                 const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias,
+                 float *restrict dscale, float *restrict dx, bool is_train) {
+#endif
+  NNACL_CHECK_ZERO_RETURN(size);
+  float N = (float)size;
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // x_hat should also be multiplied by invar[c]; that multiply is deferred to the per-channel loop below.
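+      // With x_hat = (in - mean) * invar, each channel accumulates
+      // dbias[c] += dy and dscale[c] += dy * x_hat.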
+      float x_hat = in[ix] - mean[c];
+      dscale[c] += (yt[ix] * x_hat);
+    }
+  }
+  for (int c = 0; c < ch; c++) {
+    dscale[c] *= invar[c];
+  }
+  backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train);
+}
+
+#ifdef _MSC_VER
+void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
+                int ch, float *dbias, float *dscale) {
+#else
+void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias,
+                float *restrict dscale) {
+#endif
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // x_hat should also be multiplied by invar[c]; that multiply is deferred to the per-channel loop below.
+      float x_hat = in[ix] - mean[c];
+      dscale[c] += (yt[ix] * x_hat);
+    }
+  }
+  for (int c = 0; c < ch; c++) {
+    dscale[c] *= invar[c];
+  }
+}
+
+#ifdef _MSC_VER
+void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale,
+                const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train) {
+#else
+void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                const float *restrict invar, const float *restrict dscale, const float *restrict dbias,
+                const float *restrict scale, int size, int total_size, int ch, float *restrict dx, bool is_train) {
+#endif
+  NNACL_CHECK_ZERO_RETURN(total_size);
+  const float N = (float)total_size;
+  backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train);
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..a8c03fb40401e323b6230f1761224b032a9dc6d2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ +#define CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ + +#include "nnacl_c/fp32_grad/batch_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void var2Invar(float *save_var, int size, float eps); +void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale, float *dx, bool is_train); +void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale); +void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale, + const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train); +#ifdef __cplusplus +} +#endif + +#endif // CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..271a5acdddabecae6f3d71287303622b7edd26aa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_ +#define NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct BNGradParameter { + OpParameter op_parameter_; + float epsilon_; + bool is_training_; +} BNGradParameter; + +#endif // NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c new file mode 100644 index 0000000000000000000000000000000000000000..ed422911f86d6b13f173f42b288651c5640430f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <math.h>
+#include "nnacl_c/fp32_grad/binary_cross_entropy.h"
+
+static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x,
+                                         const float *input_y, const float *weight, float *loss, float *tmp_loss,
+                                         bool weight_defined) {
+  const float epsilon = 1e-12f;
+
+  if (reduction == Reduction_None) {
+    if (weight_defined) {
+      for (int i = 0; i < input_size; i++) {
+        float value =
+          -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        loss[i] = value;
+      }
+    } else {
+      for (int i = 0; i < input_size; i++) {
+        float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        loss[i] = value;
+      }
+    }
+  } else {
+    if (weight_defined) {
+      for (int i = 0; i < input_size; i++) {
+        float value =
+          -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        tmp_loss[i] = value;
+      }
+    } else {
+      for (int i = 0; i < input_size; i++) {
+        float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        tmp_loss[i] = value;
+      }
+    }
+  }
+}
+
+void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
+                        const float *weight, float *loss, float *tmp_loss, bool weight_defined) {
+  loss[0] = 0.0f;
+  BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss, weight_defined);
+  if (reduction != Reduction_None) {
+    if (input_size % 2 == 1) {
+      tmp_loss[0] += tmp_loss[input_size - 1];
+    }
+    for (int stride = input_size / 2; stride > 0; stride = stride / 2) {
+      for (int i = 0; i < stride; i++) {
+        tmp_loss[i] += tmp_loss[i + stride];
+      }
+      if (stride > 2 && stride % 2 == 1) {
+        tmp_loss[0] += tmp_loss[stride - 1];
+      }
+    }
+    loss[0] += tmp_loss[0];
+    if (reduction == Reduction_Mean) {
+      loss[0] /= input_size;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h
new file mode 100644
index 0000000000000000000000000000000000000000..348d130fda9c1bb6a1b5898300d53b3ea9e1bdc2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_BINARY_CROSS_ENTROPY_H_ +#define NNACL_BINARY_CROSS_ENTROPY_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int reduction; +} BinaryCrossEntropyParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, float *loss, float *tmp_loss, bool weight_defined); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_BINARY_CROSS_ENTROPY_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..c4d1775408215eba8bfe7255aab2df35dd319157 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c @@ -0,0 +1,56 @@ +/* + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32_grad/binary_cross_entropy_grad.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, const float *dloss, float *dx, bool weight_defined) { + const float epsilon = 1e-12f; + if (reduction == Reduction_None) { + if (weight_defined) { + for (int i = 0; i < input_size; i++) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } else { + for (int i = 0; i < input_size; i++) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } + } else { + float dloss1 = dloss[0]; + if (reduction == Reduction_Mean) { + dloss1 = dloss[0] / input_size; + } + for (int i = 0; i < input_size; i++) { + if (weight_defined) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } else { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..bc8c3dbbaf539910f999dd7eea11aaa0f0ef5b00 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ +#define NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int reduction; +} BinaryCrossEntropyGradParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, const float *dloss, float *dx, bool weight_defined); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c new file mode 100644 index 0000000000000000000000000000000000000000..b6de2e7629e884e9b36134b712e75753e3dd9ad5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c @@ -0,0 +1,380 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp32_grad/convolution_grad_filter.h"
+#include "nnacl_c/errorcode.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+#ifdef ENABLE_ARM
+static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                           const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  for (; i_c < (out_ch & ~15); i_c += 16) {
+    float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_12x_4 = vdupq_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        // a negative index wraps to a huge unsigned value, so one compare performs 0 <= idx < bound
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+
+          float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
+          float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
+          sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
+
+          float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
+          float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
+          sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
+
+          float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
+          float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
+          sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
+
+          float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12);
+          float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12);
+          sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
+    dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
+    dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
+
+    dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
+    dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
+    dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
+    dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
+
+    dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
+    dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
+    dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
+    dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
+
+    dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0];
+    dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1];
+    dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2];
+    dw[(i_c + 15) * k_spatial + k_idx] = sum_12x_4[3];
+  }
+  return i_c;
+}
+
+static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                           const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 12) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + float32x4_t sum_9x_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + + float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8); + float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8); + sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + + dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0]; + dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1]; + dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2]; + dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3]; + + i_c += 12; + } + return i_c; +} + +static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 8) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = 
vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + i_c += 8; + } + return i_c; +} +static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 4) { + float32x4_t sum_4 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy); + sum_4 = vmlaq_f32(sum_4, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3]; + i_c += 4; + } + return i_c; +} + +static int Filtergrad2Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 2) { + float32x2_t sum_2 = vdup_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + 
input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x2_t x_4 = vld1_f32(x_addr + offset_x); + float32x2_t dy_4 = vld1_f32(dy_addr + offset_dy); + sum_2 = vmla_f32(sum_2, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1]; + i_c += 2; + } + return i_c; +} +#endif +int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + + for (int i_k = 0; i_k < count; i_k++) { + int k_idx = start + i_k; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int i_c = 0; +#ifdef ENABLE_ARM + i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param); +#endif + for (; i_c < out_ch; i_c++) { + float sum = 0; + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + sum += x_addr[offset_x] * dy_addr[offset_dy]; + } + } + } + dw[i_c * k_spatial + k_idx] = sum; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h new file mode 100644 index 0000000000000000000000000000000000000000..1ed95e72c0d6623b00b45e9309e4c78ce4a77dec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
+#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
+
+#include <stddef.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c
new file mode 100644
index 0000000000000000000000000000000000000000..c791ed916aa0c5219577443725625c85e2050225
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/convolution_grad_input.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+int ConvDwInputGrad(const float *dy, const float *w, float *dx, int start, int count, const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_spatial = conv_param->output_h_ * conv_param->output_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int k_spatial = k_h * k_w;
+  int end = start + count;
+
+  int j = start;
+  for (; j <= (end - C4NUM); j += C4NUM) {
+    float *c = dx + j;
+    const float *mat_b_0 = w + (j + C0NUM) * k_spatial;
+    const float *mat_b_1 = w + (j + C1NUM) * k_spatial;
+    const float *mat_b_2 = w + (j + C2NUM) * k_spatial;
+    const float *mat_b_3 = w + (j + C3NUM) * k_spatial;
+
+    for (int si = 0; si < out_spatial; si++) {
+      const float *a = dy + j + si * out_ch;
+#ifdef ENABLE_ARM
+      float32x4_t mat_a = vld1q_f32(a);
+#else
+      float mat_a[C4NUM] = {a[C0NUM], a[C1NUM], a[C2NUM], a[C3NUM]};
+#endif
+      int output_row = (si) / out_w;
+      int output_col = (si) % out_w;
+      for (int k = 0; k < k_spatial; k++) {
+        int row_stride_offset = output_row * conv_param->stride_h_;
+        int col_stride_offset = output_col * conv_param->stride_w_;
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+#ifdef ENABLE_ARM
+          float32x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]};
+          float32x4_t mat_c = vld1q_f32(c + offset);
+          mat_c = vmlaq_f32(mat_c, mat_b, mat_a);
+          vst1q_f32(c + offset, mat_c);
+#else
+          c[offset + C0NUM] += mat_a[C0NUM] * mat_b_0[k];
+          c[offset + C1NUM] += mat_a[C1NUM] * mat_b_1[k];
+          c[offset + C2NUM] += mat_a[C2NUM] * mat_b_2[k];
+          c[offset + C3NUM] += mat_a[C3NUM] * mat_b_3[k];
+#endif
+        }
+      }
+    }
+  }
+
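+  // Tail: channels left over after the C4NUM-wide blocks are handled one at a time below.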
+  for (; j < end; j++) {
+    float *c = dx + j;
+    const float *b = w + j * k_spatial;
+    for (int si = 0; si < out_spatial; si++) {
+      const float *a = dy + j + si * out_ch;
+      int output_row = si / out_w;
+      int output_col = si % out_w;
+      int row_stride_offset = output_row * conv_param->stride_h_;
+      int col_stride_offset = output_col * conv_param->stride_w_;
+      for (int k = 0; k < k_spatial; k++) {
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+          c[offset] += a[0] * b[k];
+        }
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h
new file mode 100644
index 0000000000000000000000000000000000000000..734b53df927881e512285f1afcc840e6f7069cca
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
+#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
+
+#include <stddef.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ConvDwInputGrad(const float *dy, const float *w, float *dx, int start, int count, const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..c04aee3652f030be627a6e1bd256ce4ab03de5e7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/dropout_grad.h"
+
+void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float scale) {
+  for (int i = 0; i < length; i++) {
+    output_ptr[i] = yt_ptr[i] * mask[i] * scale;
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..1338ec7889c56d92439b637878f6d2830831ae3e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_DROPOUT_GRAD_H_
+#define NNACL_FP32_GRAD_DROPOUT_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float scale);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GRAD_DROPOUT_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..51015ed35f48715cec6d9b4203ed3fdb5d2fcb78
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
+#define NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct {
+  OpParameter op_parameter_;
+  float ratio_;
+} DropoutParameter;
+
+#endif // NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c
new file mode 100644
index 0000000000000000000000000000000000000000..8f51b499269a1289fdae1e639c6677d576a0a7bd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c
@@ -0,0 +1,855 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/gemm.h"
+#include <string.h>
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
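+
+// Stride-aware packing helpers (RowMajor2RowN/ColNMajorStride) plus the MatSize /
+// MatSizeTotal workspace queries used by the gradient GEMM routines below.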
+
+#ifdef _MSC_VER
+void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride) {
+#else
+void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
+#endif
+  const float *src_ptr = v1;
+  float *dst_ptr = v2;
+  for (int r = 0; r < row; r++) {
+    for (int c = 0; c < col; c++) {
+      dst_ptr[c] += beta * src_ptr[c];
+    }
+    src_ptr += stride;
+    dst_ptr += stride;
+  }
+}
+
+int MatSize(int row, int col, int round) {
+  int res = UP_ROUND(row, round) * col;
+  int res1 = UP_ROUND(col, round) * row;
+  return (res > res1) ? res : res1;
+}
+
+int MatSizeTotal(int row, int col, int deep, int stride) {
+#ifdef ENABLE_ARM32
+  const int num0 = C4NUM;
+#elif ENABLE_AVX
+  const int num0 = C6NUM;
+#else
+  const int num0 = C12NUM;
+#endif
+
+#ifdef ENABLE_AVX
+  const int num1 = C16NUM;
+#else
+  const int num1 = C8NUM;
+#endif
+  int res = MatSize(row, deep, num0) + MatSize(col, deep, num1);
+  if (stride > 0) res += row * stride;
+  return res;
+}
+#ifdef ENABLE_ARM32
+static void RowMajor2Row4MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) {
+  for (int r = 0; r < row; r++) {
+    const float *src = src_ptr + r * lead;
+    for (int c = 0; c < col; c++) {
+      int cd8 = c / 4;
+      int cm8 = c % 4;
+      dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
+    }
+  }
+}
+#endif
+
+void RowMajor2Row8MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) {
+  for (int r = 0; r < row; r++) {
+    const float *src = src_ptr + r * lead;
+    for (int c = 0; c < col; c++) {
+      int cd8 = c / C8NUM;
+      int cm8 = c % C8NUM;
+      dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
+    }
+  }
+  return;
+}
+
+#ifdef ENABLE_ARM64
+static void RowMajor2Col12MajorStrideArm64(const float *src_c, float *dst_c, int lead) {
+  size_t stride = lead * sizeof(float);
+  asm volatile(
+    "mov x10, %[src_c]\n"
+    "mov x11, %[dst_c]\n"
+
+    "ld1 {v0.4s}, [x10], %[stride]\n"
+    "ld1 {v1.4s}, [x10], %[stride]\n"
+    "ld1 {v2.4s}, [x10], %[stride]\n"
+    "ld1 {v3.4s}, [x10], %[stride]\n"
+
+    "ld1 {v4.4s}, [x10], %[stride]\n"
+    "ld1 {v5.4s}, [x10], %[stride]\n"
+    "ld1 {v6.4s}, [x10], %[stride]\n"
+    "ld1 {v7.4s}, [x10], %[stride]\n"
+
+    "zip1 v12.4s, v0.4s, v1.4s\n"
+    "zip2 v13.4s, v0.4s, v1.4s\n"
+    "zip1 v14.4s, v2.4s, v3.4s\n"
+    "zip2 v15.4s, v2.4s, v3.4s\n"
+
+    "ld1 {v8.4s}, [x10], %[stride]\n"
+    "ld1 {v9.4s}, [x10], %[stride]\n"
+    "ld1 {v10.4s}, [x10], %[stride]\n"
+    "ld1 {v11.4s}, [x10], %[stride]\n"
+
+    "zip1 v16.4s, v4.4s, v5.4s\n"
+    "zip2 v17.4s, v4.4s, v5.4s\n"
+    "zip1 v18.4s, v6.4s, v7.4s\n"
+    "zip2 v19.4s, v6.4s, v7.4s\n"
+
+    "trn1 v20.2d, v12.2d, v14.2d\n"
+    "trn2 v23.2d, v12.2d, v14.2d\n"
+    "trn1 v26.2d, v13.2d, v15.2d\n"
+    "trn2 v29.2d, v13.2d, v15.2d\n"
+
+    "trn1 v21.2d, v16.2d, v18.2d\n"
+    "trn2 v24.2d, v16.2d, v18.2d\n"
+    "trn1 v27.2d, v17.2d, v19.2d\n"
+    "trn2 v30.2d, v17.2d, v19.2d\n"
+
+    "zip1 v12.4s, v8.4s,
v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif // ENABLE_ARM64 + +#ifdef ENABLE_ARM32 +void RowMajor2Col12MajorStrideArm32(const float *src_c, float *dst_c, int lead) { + size_t stride = lead * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q10}, [r10], %[stride]\n" + "vld1.32 {q13}, [r10], %[stride]\n" + + "vtrn.32 d0, d6\n" + "vtrn.32 d1, d7\n" + "vtrn.32 d20, d26\n" + "vtrn.32 d21, d27\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q8}, [r10], %[stride]\n" + "vld1.32 {q11}, [r10], %[stride]\n" + "vld1.32 {q14}, [r10], %[stride]\n" + + "vswp d1, d20\n" + "vswp d7, d26\n" + + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q9}, [r10], %[stride]\n" + "vld1.32 {q12}, [r10], %[stride]\n" + "vld1.32 {q15}, [r10], %[stride]\n" + + "vtrn.32 d2, d16\n" + "vtrn.32 d3, d17\n" + "vtrn.32 d22, d28\n" + "vtrn.32 d23, d29\n" + + "vswp d3, d22\n" + "vswp d17, d28\n" + + "vtrn.32 d4, d18\n" + "vtrn.32 d5, d19\n" + "vtrn.32 d24, d30\n" + "vtrn.32 d25, d31\n" + + "vswp d5, d24\n" + "vswp d19, d30\n" + + "vst1.32 {q0, q1}, [r12]!\n" + "vst1.32 {q2, q3}, [r12]!\n" + "vst1.32 {q8, q9}, [r12]!\n" + "vst1.32 {q10, q11}, [r12]!\n" + "vst1.32 {q12, q13}, [r12]!\n" + "vst1.32 {q14, q15}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +} +#endif // ENABLE_ARM32 + +#ifndef ENABLE_ARM32 +void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / C12NUM; + int cm8 = c % C12NUM; + dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c]; + } + } + return; +} + +void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + + /* 12x4 row-major to col-major */ +#ifdef ENABLE_ARM64 + RowMajor2Col12MajorStrideArm64(src_c, dst_c, lead); +#elif ENABLE_ARM32 + RowMajor2Col12MajorStrideArm32(src_c, dst_c, lead); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + for (int i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + 
src_r += C12NUM * lead; + dst_r += C12NUM * col; + } + + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + for (; ri < row_up_12; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} +#endif + +#ifdef ENABLE_ARM64 +static void RowMajor2Col8MajorStrideArm64(const float *src_c, float *dst_c, int lead) { + /* 8x8 row-major to col-major */ + size_t stride = lead * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n" + "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + + "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n" + "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n" + "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n" + "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v0.2d, v8.2d, v10.2d\n" + "trn2 v1.2d, v8.2d, v10.2d\n" + "trn1 v2.2d, v9.2d, v11.2d\n" + "trn2 v3.2d, v9.2d, v11.2d\n" + + "zip1 v24.4s, v16.4s, v18.4s\n" + "zip2 v25.4s, v16.4s, v18.4s\n" + "zip1 v26.4s, v20.4s, v22.4s\n" + "zip2 v27.4s, v20.4s, v22.4s\n" + + "trn1 v4.2d, v12.2d, v14.2d\n" + "trn2 v5.2d, v12.2d, v14.2d\n" + "trn1 v6.2d, v13.2d, v15.2d\n" + "trn2 v7.2d, v13.2d, v15.2d\n" + + "zip1 v28.4s, v17.4s, v19.4s\n" + "zip2 v29.4s, v17.4s, v19.4s\n" + "zip1 v30.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + + "trn1 v16.2d, v24.2d, v26.2d\n" + "trn2 v17.2d, v24.2d, v26.2d\n" + "trn1 v18.2d, v25.2d, v27.2d\n" + "trn2 v19.2d, v25.2d, v27.2d\n" + + "trn1 v20.2d, v28.2d, v30.2d\n" + "trn2 v21.2d, v28.2d, v30.2d\n" + "trn1 v22.2d, v29.2d, v31.2d\n" + "trn2 v23.2d, v29.2d, v31.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v16.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v17.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11], #16\n" + "st1 {v18.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11], #16\n" + "st1 {v19.4s}, [x11], #16\n" + "st1 {v4.4s}, [x11], #16\n" + "st1 {v20.4s}, [x11], #16\n" + "st1 {v5.4s}, [x11], #16\n" + "st1 {v21.4s}, [x11], #16\n" + "st1 {v6.4s}, [x11], #16\n" + "st1 {v22.4s}, [x11], #16\n" + "st1 {v7.4s}, [x11], #16\n" + "st1 {v23.4s}, [x11], #16\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif // ENABLE_ARM64 + +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +static void RowMajor2Col8MajorStrideArm32(const float *src_c, float *dst_c, size_t col) { + /* 8x4 row-major to col-major */ + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r11, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + 
"vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r11]!\n" + "vst1.32 {q2, q3}, [r11]!\n" + "vst1.32 {q4, q5}, [r11]!\n" + "vst1.32 {q6, q7}, [r11]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} + +#else +static void RowMajor2Col8MajorStrideArm32Nnie(const float *src_c, float *dst_c, size_t col) { + /* 8x4 row-major to col-major */ + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r7, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r7]!\n" + "vst1.32 {q2, q3}, [r7]!\n" + "vst1.32 {q4, q5}, [r7]!\n" + "vst1.32 {q6, q7}, [r7]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} +#endif // SUPPORT_NNIE +#endif // ENABLE_ARM32 + +void RowMajor2Col8MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row8 = row / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + size_t col_skip = col / C8NUM * C8NUM; + size_t skip_size = C8NUM; +#else + size_t col_skip = col / C4NUM * C4NUM; + size_t skip_size = C4NUM; +#endif + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C8NUM) { + size_t ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8MajorStrideArm64(src_c, dst_c, lead); +#elif ENABLE_ARM32 +#ifndef SUPPORT_NNIE + RowMajor2Col8MajorStrideArm32(src_c, dst_c, lead); +#else + RowMajor2Col8MajorStrideArm32Nnie(src_c, dst_c, lead); +#endif +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C8NUM * lead; + dst_r += C8NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + return; +} +#ifdef ENABLE_ARM32 +static void RowMajor2Col4MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row8 = row / C4NUM * C4NUM; + size_t col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C4NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + + /* 4x4 row-major to col-major */ +#ifdef ENABLE_ARM32 + size_t stride = col * 4; + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + + 
"vtrn.32 d0, d2\n" + "vtrn.32 d1, d3\n" + "vtrn.32 d4, d6\n" + "vtrn.32 d5, d7\n" + + "vswp d1, d4\n" + "vswp d3, d6\n" + + "vst1.32 {q0}, [r12]!\n" + "vst1.32 {q1}, [r12]!\n" + "vst1.32 {q2}, [r12]!\n" + "vst1.32 {q3}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3"); +#else + for (int tr = 0; tr < C4NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C4NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + return; +} +#endif + +void RowMajor2Row6MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int max = 0; + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + int c = 0; + for (; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + int offset = cd6 * C6NUM * row + r * C6NUM + cm6; + dst_ptr[offset] = src[c]; + if (offset > max) { + max = offset; + } + } + for (; c < UP_ROUND(col, C6NUM); c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + int offset = cd6 * C6NUM * row + r * C6NUM + cm6; + dst_ptr[offset] = 0.0f; + if (offset > max) { + max = offset; + } + } + } + return; +} + +void RowMajor2Col6MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int totalRow = UP_ROUND(row, C6NUM); + int row6 = row / C6NUM * C6NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + int ri = 0; + for (; ri < row6; ri += C6NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + lead); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * lead); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * lead); + __m256 src4 = _mm256_loadu_ps(src_c + 4 * lead); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * lead); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 
= _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + C4NUM, res1); + _mm_storeu_ps(dst_c + C8NUM, res2); + _mm_storeu_ps(dst_c + C12NUM, res3); + _mm_storeu_ps(dst_c + C16NUM, res4); + _mm_storeu_ps(dst_c + C20NUM, res5); + _mm_storeu_ps(dst_c + C24NUM, res6); + _mm_storeu_ps(dst_c + C28NUM, res7); + _mm_storeu_ps(dst_c + C32NUM, res8); + _mm_storeu_ps(dst_c + C36NUM, res9); + _mm_storeu_ps(dst_c + C40NUM, res10); + _mm_storeu_ps(dst_c + C44NUM, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (int i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C6NUM * lead; + dst_r += C6NUM * col; + } + + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + for (; ri < totalRow; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Col16MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int row16 = row / C16NUM * C16NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + int ri = 0; + for (; ri < row16; ri += C16NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_AVX + Transpose8X8Fp32Avx(src_c, dst_c, lead, C16NUM); + Transpose8X8Fp32Avx(src_c + C8NUM * lead, dst_c + C8NUM, lead, C16NUM); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C16NUM * lead; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + int total_row = UP_ROUND(row, C16NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Row16MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int max = 0; + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + int c = 0; + for (; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c]; + if ((cd16 * C16NUM * row + r * C16NUM + cm16) > max) max = cd16 * C16NUM * row + r * C16NUM + cm16; + } + for (; c < UP_ROUND(col, C16NUM); c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0; + if ((cd16 * C16NUM * row + r * C16NUM + cm16) > max) max = cd16 * C16NUM * row + r * C16NUM + cm16; + } + } + return; +} +void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, 
float beta, float *mat_c, int ldc, float *workspace) { + GemmCb gcb; + gcb.atype = ActType_No; + gcb.ca = 0; + gcb.cb = 0; + gcb.bias = NULL; + gcb.mat_a = NULL; + gcb.mat_b = NULL; + GemmMatmulPlus(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb); +} + +void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *gcb) { +#ifdef ENABLE_ARM32 + const int num = C4NUM; + const int num1 = C8NUM; +#elif ENABLE_AVX + const int num = C6NUM; + const int num1 = C16NUM; +#else + const int num = C12NUM; + const int num1 = C8NUM; +#endif + float *output = mat_c; + float *fworkspace = workspace; + int incremental = (beta < 0.f) || (beta > 0.f); + float *mat_a_input = (float *)mat_a; + float *mat_b_input = (float *)mat_b; + + if (!gcb->ca) { + mat_a_input = fworkspace; + if (ta) { + fworkspace += MatSize(K, M, num); +#ifdef ENABLE_ARM32 + RowMajor2Row4MajorStride(mat_a, mat_a_input, K, M, lda); +#elif ENABLE_AVX + RowMajor2Row6MajorStride(mat_a, mat_a_input, K, M, lda); +#else + RowMajor2Row12MajorStride(mat_a, mat_a_input, K, M, lda); +#endif + } else { + fworkspace += MatSize(M, K, num); +#ifdef ENABLE_ARM32 + RowMajor2Col4MajorStride(mat_a, mat_a_input, M, K, lda); +#elif ENABLE_AVX + RowMajor2Col6MajorStride(mat_a, mat_a_input, M, K, lda); +#else + RowMajor2Col12MajorStride(mat_a, mat_a_input, M, K, lda); +#endif + } + } + if (!gcb->cb) { + mat_b_input = fworkspace; + if (tb) { + fworkspace += MatSize(N, K, num1); +#ifdef ENABLE_AVX + RowMajor2Col16MajorStride(mat_b, mat_b_input, N, K, ldb); +#else + RowMajor2Col8MajorStride(mat_b, mat_b_input, N, K, ldb); +#endif + } else { + fworkspace += MatSize(K, N, num1); +#ifdef ENABLE_AVX + RowMajor2Row16MajorStride(mat_b, mat_b_input, K, N, ldb); +#else + RowMajor2Row8MajorStride(mat_b, mat_b_input, K, N, ldb); +#endif + } + } + if (incremental) output = fworkspace; +#ifdef ENABLE_ARM32 + MatmulFloatNeon32Opt(mat_a_input, mat_b_input, output, gcb->bias, (int)gcb->atype, K, M, N, ldc, 1); +#else + MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); +#endif + if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc); + gcb->mat_a = mat_a_input; + gcb->mat_b = mat_b_input; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..9d2f6e496b432fc1f3434ba8539a66e6724f8162 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
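(Editor's sketch, not part of the patch: a minimal call of the GemmMatmul wrapper above. It assumes MatSizeTotal(M, N, K, ldc) bounds the packing workspace the same way GetGemmMatMullWorkspace in lstm_grad_fp32.c uses it; all values are illustrative.)

#include <stdio.h>
#include <stdlib.h>
#include "nnacl_c/fp32_grad/gemm.h"

int main(void) {
  const int M = 2, N = 2, K = 3;
  float a[] = {1, 2, 3, 4, 5, 6}; /* M x K, row-major, lda = K */
  float b[] = {1, 0, 0, 1, 1, 1}; /* K x N, row-major, ldb = N */
  float c[4] = {0};               /* M x N, ldc = N */
  float *ws = (float *)malloc((size_t)MatSizeTotal(M, N, K, N) * sizeof(float));
  if (ws == NULL) return 1;
  /* ta = tb = 0: neither operand transposed; beta = 0 overwrites c. */
  GemmMatmul(0, 0, M, N, K, 1.0f, a, K, b, N, 0.0f, c, N, ws);
  printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]); /* expect 4 5 / 10 11 */
  free(ws);
  return 0;
}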
+ */
+
+#ifndef NNACL_FP32_GRAD_GEMM_H_
+#define NNACL_FP32_GRAD_GEMM_H_
+
+#include <stdlib.h>
+#include "nnacl_c/op_base.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+typedef struct {
+  int ca;
+  int cb;
+  ActType atype;
+  float *bias;
+  float *mat_a;
+  float *mat_b;
+} GemmCb;
+
+void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b,
+                    int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *cb);
+void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b,
+                int ldb, float beta, float *mat_c, int ldc, float *workspace);
+int MatSize(int row, int col, int round);
+int MatSizeTotal(int row, int col, int deep, int inc);
+void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_GEMM_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..ce0bce9213c9848e6e80fb284763db9c8af18fec
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32_grad/layernorm_grad.h"
+#include <stddef.h>
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+
+int LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
+                  int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) {
+  // var is actually the layer_norm forward output variance
+  const float eps = 1e-12;
+  const float *var_sqrt_rev = var;
+  if (block_size <= 0) {
+    return NNACL_ERRCODE_DIVISOR_ZERO;
+  }
+  for (int i = 0; i < param_num; ++i) {
+    float dgamma = 0.0f;
+    float dbeta = 0.0f;
+    for (int j = i; j < param_size * param_num; j += param_num) {
+      int norm_shift = (int)(j / block_size);
+      dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]);
+      dbeta += dy[j];
+    }
+    dg[i] = dgamma;
+    db[i] = dbeta;
+  }
+  for (int i = 0; i < block_num; ++i) {
+    float sum1 = 0.0f;
+    float sum2 = 0.0f;
+    float sum3 = 0.0f;
+    for (int j = 0; j < block_size; ++j) {
+      int index = i * block_size + j;
+      float dxm = x[index] - mean[i];
+      int param_shift = index % param_num;
+      float dyg = dy[index] * gamma[param_shift];
+      sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[i] + eps, -1.5);
+      sum2 += dyg;
+      sum3 += -2.0f * dxm;
+    }
+    for (int j = 0; j < block_size; ++j) {
+      int index = i * block_size + j;
+      float var_sqrt = pow(var_sqrt_rev[i] + eps, -0.5);
+      int param_shift = index % param_num;
+      float dx1 = dy[index] * gamma[param_shift] * var_sqrt;
+      float dx2 = sum1 * 2.0f / block_size * (x[index] - mean[i]);
+      float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);
+      dx[index] = dx1 + dx2 + dx3;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..8558c448241b0a0f099c887a26aaf4f0fed629f0
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
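(Editor's sketch, not part of the patch: how the LayerNormGrad arguments above map onto a [batch, channels] tensor normalized over its last axis with per-channel gamma. The shapes and the wrapper are illustrative only.)

#include "nnacl_c/fp32_grad/layernorm_grad.h"

enum { BATCH = 4, CHANNELS = 8 };

/* gamma has param_num (= CHANNELS) entries, each accumulated over param_size
 * (= BATCH) occurrences; var/mean hold one value per normalized row, i.e.
 * block_num (= BATCH) blocks of block_size (= CHANNELS) elements. */
int LayerNormGradExample(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
                         float *dx, float *dg, float *db) {
  return LayerNormGrad(x, dy, var, mean, gamma, CHANNELS, BATCH, BATCH, CHANNELS, dx, dg, db);
}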
+ */ +#ifndef NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ +#define NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, + int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ebb3c8b791684d3e95a9179d7611437abedf64f7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ +#define NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormGradParameter; + +#endif // NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8400af94f014c65d7d0daf031a25243a5fc12b0c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
+#include <string.h>
+#include <float.h>
+#include "nnacl_c/lstm_parameter.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32_grad/gemm.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/nnacl_utils.h"
+
+static const int num_of_gates = 4;
+static const int no_of_temp_matrices_sized_output_step = 5;
+
+static inline float *AllocteFromScrachPad(float **scrach_pad, int size) {
+  float *buffer = *scrach_pad;
+  *scrach_pad += size;
+  return buffer;
+}
+
+static const int weights_order_IOFG[2 * 4] = {0, 3, 1, 2, 4, 7, 5, 6};  // IOFG order to IFGO order
+static const int weights_order_IFGO[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5};  // IFGO order to IOFG order
+
+const int *getLstmOrderIOFG(void) { return weights_order_IOFG; }
+
+const int *getLstmOrderIFGO(void) { return weights_order_IFGO; }
+
+void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align,
+                             const int *order) {
+  for (int i = 0; i < batch; i++) {
+    const float *src_batch = src + i * col * row;
+    float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col * row_align;
+#ifdef ENABLE_AVX
+    RowMajor2Row16Major(src_batch, dst_batch, row, col);
+#elif defined(ENABLE_ARM32)
+    RowMajor2Row4Major(src_batch, dst_batch, row, col);
+#else
+    RowMajor2Row8Major(src_batch, dst_batch, row, col);
+#endif
+  }
+}
+
+void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order) {
+  int matrix_size = col * row;
+  for (int i = 0; i < nof_martices; i++) {
+    const float *src_block = src + i * matrix_size;
+    float *dst_block = dst + ((order == NULL) ? i : order[i]) * matrix_size;
+    memcpy(dst_block, src_block, matrix_size * sizeof(float));
+  }
+}
+
+void sumCols(int m, int n, int stride, float *inMat, float *outMat, bool accumulate) {
+  for (int idn = 0; idn < n; idn++) {
+    float *col = inMat + idn;
+    if (!accumulate) {
+      *outMat = 0;
+    }
+    for (int idm = 0; idm < m; idm++) {
+      *outMat += *col;
+      col += stride;
+    }
+    outMat++;
+  }
+}
+
+int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) {
+  int workspace_size, temp;
+  // If the corresponding GemmMatmul call uses beta > 0, MatSizeTotal must be given col as its last parameter.
+  workspace_size = MatSizeTotal(batch, input_size, hidden_size, input_size);
+  temp = MatSizeTotal(batch, hidden_size, hidden_size, hidden_size);
+  workspace_size = (temp > workspace_size) ? temp : workspace_size;
+  temp = MatSizeTotal(hidden_size, input_size, batch, input_size);
+  workspace_size = (temp > workspace_size) ? temp : workspace_size;
+  temp = MatSizeTotal(hidden_size, hidden_size, batch, hidden_size);
+  workspace_size = (temp > workspace_size) ?
temp : workspace_size; + return workspace_size; +} + +int GetRunWorkspaceSize(const LstmGradParameter *lstm_param) { + int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; + int workspace_size = no_of_temp_matrices_sized_output_step * time_stamp_len; + workspace_size += GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_); + return workspace_size; +} + +size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param) { + int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; + return no_of_temp_matrices_sized_output_step * time_stamp_len; +} + +void LstmGradReorderDy(float *src, float *dst, LstmGradParameter *lstm_param) { + int dir_mult = lstm_param->bidirectional_ ? C2NUM : C1NUM; + for (int b = 0; b < lstm_param->batch_; b++) { + int batch_offset = b * dir_mult * lstm_param->hidden_size_; + float *dy = src + batch_offset; + memcpy(dst + b * lstm_param->hidden_size_, dy, lstm_param->hidden_size_ * sizeof(float)); + } +} + +void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, + float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, + float *w, float *v, float *workspace, const LstmGradParameter *lstm_param) { + float *scratchPad = workspace; + + int seq_len = lstm_param->batch_ * lstm_param->hidden_size_; + float *temp0 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp1 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp2 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp3 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp4 = AllocteFromScrachPad(&scratchPad, seq_len); + + // Accumulate gradients into dH + ElementAdd(dH, dY, dH, seq_len); + + ElementMul(dH, output_gate, temp1, seq_len); + Tanh(cell_state, seq_len, temp0); + ElementMul(temp0, temp0, temp2, seq_len); + ElementMul(temp1, temp2, temp4, seq_len); + ElementSub(temp1, temp4, temp1, seq_len); + ElementAdd(dC, temp1, dC, seq_len); + + // calculate dI, dO, dF and dG + float *dI = temp1; // dI = dC_{t} * G + ElementMul(dC, cell_gate, dI, seq_len); + float *dO = temp2; // dO = dH * Tanh(C_{t}) + ElementMul(dH, temp0, dO, seq_len); + float *dF = temp3; // dF = dC_{t} * C_{t-1} + ElementMul(dC, prev_cell_state, dF, seq_len); + float *dG = temp4; // dG = dC_{t} * I + ElementMul(dC, input_gate, dG, seq_len); + + // dAi = dI * I * (1 - I) + float *dAi = temp1; + *dA = dAi; + ElementMul(dI, input_gate, dAi, seq_len); + ElementMul(dAi, input_gate, temp0, seq_len); + ElementSub(dAi, temp0, dAi, seq_len); + + // dAo = dO * O * (1 - O) + float *dAo = temp2; + ElementMul(dO, output_gate, dAo, seq_len); + ElementMul(dAo, output_gate, temp0, seq_len); + ElementSub(dAo, temp0, dAo, seq_len); + + // dAf = dF * F * (1 - F) + float *dAf = temp3; + ElementMul(dF, forget_gate, dAf, seq_len); + ElementMul(dAf, forget_gate, temp0, seq_len); + ElementSub(dAf, temp0, dAf, seq_len); + + float *dAg = temp4; + ElementMul(cell_gate, cell_gate, temp0, seq_len); + ElementMul(dG, temp0, temp0, seq_len); + ElementSub(dG, temp0, dAg, seq_len); + + float *mat_workspace = AllocteFromScrachPad( + &scratchPad, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); + float *weights_loop = w; + float *dA_loop = dAi; // dAi, dAo, dAf, dAg + for (int idx = 0; idx < num_of_gates; idx++) { + GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop, + lstm_param->hidden_size_, 
weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_, mat_workspace);
+    weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_;
+    dA_loop += seq_len;
+  }
+
+  // calculate dH next
+  size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float);
+  memset(dH, 0, dH_size);
+  dA_loop = dAi;
+  weights_loop = v;
+  for (int idx = 0; idx < num_of_gates; idx++) {
+    GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop,
+               lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_,
+               mat_workspace);
+    weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
+    dA_loop += seq_len;
+  }
+  // calculate dC next
+  ElementMul(dC, forget_gate, dC, seq_len);
+}
+
+void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB,
+                          float *workspace, const LstmGradParameter *lstm_param) {
+  // Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
+  int seq_len = lstm_param->batch_ * lstm_param->hidden_size_;
+  float *mat_workspace = AllocteFromScrachPad(
+    &workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
+  float *dA_loop = dA;  // dAi, dAo, dAf, dAg
+  int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_;
+  int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_;
+  int dB_size = 0;
+  float *dW_loop = dW;
+  float *dV_loop = dV;
+  float *dB_loop = NULL;
+  if (lstm_param->has_bias_) {
+    dB_loop = dB;
+    dB_size = lstm_param->hidden_size_;
+  }
+
+  for (int idx = 0; idx < num_of_gates; idx++) {
+    // Calc dW
+    GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop,
+               lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_,
+               mat_workspace);
+    // Calc dV
+    GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop,
+               lstm_param->hidden_size_, prev_hidden_state, lstm_param->hidden_size_, 1.0, dV_loop,
+               lstm_param->hidden_size_, mat_workspace);
+    // Calc dB
+    if (dB_loop != NULL) {
+      sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true);
+    }
+    dA_loop += seq_len;
+    dW_loop += dW_size;
+    dV_loop += dV_size;
+    dB_loop += dB_size;
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..12bc3758a446959cb3ccbef801a5fb057c0a4a85
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
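(Editor's sketch, not part of the patch: allocating the scratch buffer the backward steps above expect. GetRunWorkspaceSize already covers the five per-step temporaries plus the GemmMatmul scratch that follows them, and GetRunWorkspaceGemmOffset locates the latter; the helper name is hypothetical.)

#include <stdlib.h>
#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"

float *AllocLstmGradWorkspace(const LstmGradParameter *p) {
  /* Size in floats, as consumed by AllocteFromScrachPad in the steps above. */
  return (float *)malloc((size_t)GetRunWorkspaceSize(p) * sizeof(float));
}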
+ */ + +#ifndef NNACL_FP32_GRAD_LSTM_GRAD_H_ +#define NNACL_FP32_GRAD_LSTM_GRAD_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LstmGradParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + float zoneout_cell_; + float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; + int has_bias_; +} LstmGradParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +const int *getLstmOrderIOFG(void); + +const int *getLstmOrderIFGO(void); + +int GetRunWorkspaceSize(const LstmGradParameter *lstm_param); + +size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param); + +void LstmGradReorderDy(float *src, float *dst, LstmGradParameter *lstm_param); + +void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align, + const int *order); + +void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order); + +void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, + float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, + float *w, float *v, float *workspace, const LstmGradParameter *lstm_param); + +void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB, + float *workspace, const LstmGradParameter *lstm_param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_LSTM_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..5ffdb880cf6c5f3d11473e64ee9f060e1bd0e23f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c @@ -0,0 +1,147 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32_grad/maxpool_grad_grad.h" +#include "nnacl_c/errorcode.h" + +int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + PoolingParameter *param, PoolingComputeParam *args) { + const int channel = args->input_channel_; + const int input_height = args->input_h_; + const int input_width = args->input_w_; + + const int window_height = args->window_h_; + const int window_width = args->window_w_; + + const int stride_height = param->stride_h_; + const int stride_width = param->stride_w_; + + const int pad_top = param->pad_u_; + const int pad_left = param->pad_l_; + + const int output_height = args->output_h_; + NNACL_CHECK_ZERO_RETURN_ERR(output_height); + const int output_width = args->output_w_; + NNACL_CHECK_ZERO_RETURN_ERR(output_width); + + const int output_chw = channel * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_chw); + const int output_hw = output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_hw); + + for (size_t pos = start; pos < end; pos++) { + const int pos_n = pos / output_chw; + const int pos_c = pos / output_hw % channel; + const int pos_h = pos / output_width % output_height; + const int pos_w = pos % output_width; + + int h_start = pos_h * stride_height - pad_top; + int w_start = pos_w * stride_width - pad_left; + const int h_end = MSMIN(h_start + window_height, input_height); + const int w_end = MSMIN(w_start + window_width, input_width); + h_start = MSMAX(h_start, 0); + w_start = MSMAX(w_start, 0); + + int input_start = pos_n * channel * input_height * input_width + pos_c * input_height * input_width; + int max_idx = h_start * input_width + w_start; + float max_data = input[input_start + max_idx]; + + for (int h_cur = h_start; h_cur < h_end; ++h_cur) { + for (int w_cur = w_start; w_cur < w_end; ++w_cur) { + int input_idx = h_cur * input_width + w_cur; + float input_data = input[input_start + input_idx]; + if (input_data > max_data) { + max_idx = input_idx; + max_data = input_data; + } + } + } + output[pos] = grad[input_start + max_idx]; + } + return NNACL_OK; +} + +int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + Pooling3DParameter *param, PoolingComputeParam *args) { + PoolingParameter *param_2d = (PoolingParameter *)(param); + const int channel = args->input_channel_; + const int input_depth = param->input_d_; + const int input_height = args->input_h_; + const int input_width = args->input_w_; + + const int window_depth = param->window_d_; + const int window_height = args->window_h_; + const int window_width = args->window_w_; + + const int stride_depth = param->stride_d_; + const int stride_height = param_2d->stride_h_; + const int stride_width = param_2d->stride_w_; + + const int pad_front = param->pad_f_; + const int pad_top = param_2d->pad_u_; + const int pad_left = param_2d->pad_l_; + + const int output_depth = param->output_d_; + NNACL_CHECK_ZERO_RETURN_ERR(output_depth); + const int output_height = args->output_h_; + NNACL_CHECK_ZERO_RETURN_ERR(output_height); + const int output_width = args->output_w_; + NNACL_CHECK_ZERO_RETURN_ERR(output_width); + + const int output_cdhw = channel * output_depth * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_cdhw); + const int output_dhw = output_depth * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_dhw); + const int output_hw = output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_hw); + + for (size_t pos = start; 
pos < end; pos++) { + const int pos_n = pos / output_cdhw; + const int pos_c = pos / output_dhw % channel; + const int pos_d = pos / output_hw % output_depth; + const int pos_h = pos / output_width % output_height; + const int pos_w = pos % output_width; + + int d_start = pos_d * stride_depth - pad_front; + int h_start = pos_h * stride_height - pad_top; + int w_start = pos_w * stride_width - pad_left; + const int d_end = MSMIN(d_start + window_depth, input_depth); + const int h_end = MSMIN(h_start + window_height, input_height); + const int w_end = MSMIN(w_start + window_width, input_width); + d_start = MSMAX(d_start, 0); + h_start = MSMAX(h_start, 0); + w_start = MSMAX(w_start, 0); + + int input_start = + pos_n * channel * input_depth * input_height * input_width + pos_c * input_depth * input_height * input_width; + int max_idx = d_start * input_height * input_width + h_start * input_width + w_start; + float max_data = input[input_start + max_idx]; + + for (int d_cur = d_start; d_cur < d_end; ++d_cur) { + for (int h_cur = h_start; h_cur < h_end; ++h_cur) { + for (int w_cur = w_start; w_cur < w_end; ++w_cur) { + int input_idx = d_cur * input_height * input_width + h_cur * input_width + w_cur; + float input_data = input[input_start + input_idx]; + if (input_data > max_data) { + max_idx = input_idx; + max_data = input_data; + } + } + } + } + output[pos] = grad[input_start + max_idx]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..9edeef77c2641740bdeafcea5ad8148dcd0dcbe0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
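(Editor's note, not part of the patch: both grad-grad kernels above walk a flattened output range [start, end). For readers checking the arithmetic, this is the NCHW decomposition the 2-D kernel applies to each position; the helper itself is illustrative.)

/* Split a flat NCHW offset into its (n, c, h, w) coordinates. */
static void DecomposeNchwPos(int pos, int channel, int out_h, int out_w, int *n, int *c, int *h, int *w) {
  *n = pos / (channel * out_h * out_w);
  *c = pos / (out_h * out_w) % channel;
  *h = pos / out_w % out_h;
  *w = pos % out_w;
}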
+ */ + +#ifndef NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_ +#define NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + PoolingParameter *param, PoolingComputeParam *args); + +int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + Pooling3DParameter *param, PoolingComputeParam *args); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..63c452f8fdfc6446f9869d403d5c811fa7177440 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32_grad/nllloss_grad_fp32.h" + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +int NLLLossGrad(const float *logits, const float *loss_grad, const int *labels, const float *weight, + const float *total_weight, float *logits_grad, int batch, int class_num, ReductionType reduction_type) { + if (logits == NULL || loss_grad == NULL || labels == NULL || weight == NULL || total_weight == NULL || + logits_grad == NULL) { + return NNACL_NULL_PTR; + } + + memset(logits_grad, 0, batch * class_num * sizeof(float)); + for (int i = 0; i < batch; i++) { + int index = i * class_num + labels[i]; + float n_weight = weight[labels[i]]; + if (reduction_type == Reduction_Sum) { + logits_grad[index] = -loss_grad[0] * n_weight; + } else if (reduction_type == Reduction_Mean) { + logits_grad[index] = -loss_grad[0] * n_weight / *total_weight; + } else { + logits_grad[index] = -loss_grad[i] * n_weight; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c49bf25fce44103fc272a99266b649fff1f5c8d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ +#define NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int NLLLossGrad(const float *logits, const float *loss_grad, const int *labels, const float *weight, + const float *total_weight, float *logits_grad, int batch, int class_num, ReductionType reduction_type); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..835c40d60e606b6c94f9fb11887041ba393230e7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_OPTIMIZER_H_ +#define NNACL_FP32_GRAD_OPTIMIZER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + bool use_nesterov_; + float grad_scale_; +} ApplyMomentumParameter; + +typedef struct { + OpParameter op_parameter_; + float dampening_; + bool use_nesterov_; + float weight_decay_; +} SgdParameter; + +typedef struct { + OpParameter op_parameter_; + bool use_nesterov_; +} AdamParameter; + +#endif // NNACL_FP32_GRAD_OPTIMIZER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c new file mode 100644 index 0000000000000000000000000000000000000000..235b859936259c976e1f4688e22fa031e1ccc03b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
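(Editor's sketch, not part of the patch: a call of the NLLLossGrad kernel above for a mean-reduced loss over a two-sample batch. The buffers, values, and the total_weight convention — the sum of the picked classes' weights — are illustrative assumptions.)

#include "nnacl_c/fp32_grad/nllloss_grad_fp32.h"

void NlllossGradExample(void) {
  enum { kBatch = 2, kClasses = 3 };
  const float logits[kBatch * kClasses] = {0.0f}; /* only the shape matters for the gradient */
  const float loss_grad[1] = {1.0f};              /* scalar dL for mean reduction */
  const int labels[kBatch] = {2, 0};
  const float weight[kClasses] = {1.0f, 1.0f, 1.0f};
  const float total_weight = 2.0f;                /* assumed: sum of the picked class weights */
  float logits_grad[kBatch * kClasses];
  (void)NLLLossGrad(logits, loss_grad, labels, weight, &total_weight, logits_grad, kBatch, kClasses, Reduction_Mean);
}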
+ */
+
+#include <string.h>
+#include "nnacl_c/fp32_grad/pack_ext.h"
+
+void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig,
+                                 int real_cal_num, int start) {
+  const int pad_left = conv_param->pad_l_;
+  const int pad_up = conv_param->pad_u_;
+
+  const int stride_h = conv_param->stride_h_;
+  const int stride_w = conv_param->stride_w_;
+
+  const int dilation_h = conv_param->dilation_h_;
+  const int dilation_w = conv_param->dilation_w_;
+
+  const int kernel_h = conv_param->kernel_h_;
+  const int kernel_w = conv_param->kernel_w_;
+
+  const int in_height = conv_param->input_h_;
+  const int in_width = conv_param->input_w_;
+
+  const int output_w = conv_param->output_w_;
+
+  const int channels = conv_param->input_channel_;
+  const int stride = kernel_h * kernel_w;
+
+  int kernel_row, kernel_col;
+
+  for (int i = 0; i < real_cal_num; i++) {
+    int block_start = start + i;
+    int input_h = block_start / output_w * stride_h;
+    int input_w = block_start % output_w * stride_w;
+    float *data_col = data_col_orig + i * channels * stride;
+    for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+      int input_row = -pad_up + kernel_row * dilation_h + input_h;
+      for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+        int input_col = -pad_left + kernel_col * dilation_w + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
+          const int offset = (input_row * in_width + input_col) * channels;
+          for (int c = 0; c < channels; c++) {
+            data_col[c * stride] = in_data[offset + c];
+          }
+          data_col++;
+        } else {
+          for (int c = 0; c < channels; c++) {
+            data_col[c * stride] = 0;
+          }
+          data_col++;
+        }
+      }
+    }
+  }
+}
+
+void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
+                        int start) {
+  const int pad_left = conv_param->pad_l_;
+  const int pad_up = conv_param->pad_u_;
+
+  const int stride_h = conv_param->stride_h_;
+  const int stride_w = conv_param->stride_w_;
+
+  const int dilation_h = conv_param->dilation_h_;
+  const int dilation_w = conv_param->dilation_w_;
+
+  const int kernel_h = conv_param->kernel_h_;
+  const int kernel_w = conv_param->kernel_w_;
+
+  const int in_height = conv_param->input_h_;
+  const int in_width = conv_param->input_w_;
+
+  const int output_w = conv_param->output_w_;
+
+  const int channels = conv_param->input_channel_ / conv_param->group_;
+  const int tot_channels = conv_param->input_channel_;
+
+  int kernel_row, kernel_col;
+
+  if (channels == 1) {
+    for (int i = 0; i < real_cal_num; i++) {
+      int block_start = start + i;
+      int input_h = block_start / output_w * stride_h;
+      int input_w = block_start % output_w * stride_w;
+      for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+        int input_row = -pad_up + kernel_row * dilation_h + input_h;
+        for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+          int input_col = -pad_left + kernel_col * dilation_w + input_w;
+          if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
+            const int offset = (input_row * in_width + input_col) * tot_channels;
+            *data_col = in_data[offset];
+            data_col++;
+          } else {
+            *data_col = 0;
+            data_col++;
+          }
+        }
+      }
+    }
+  } else {
+    for (int i = 0; i < real_cal_num; i++) {
+      int block_start = start + i;
+      int input_h = block_start / output_w * stride_h;
+      int input_w = block_start % output_w * stride_w;
+      for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+        int input_row = -pad_up +
kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(data_col, in_data + offset, sizeof(float) * channels); + data_col += channels; + } else { + memset(data_col, 0, sizeof(float) * channels); + data_col += channels; + } + } + } + } + } +} + +void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index) { + rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); +} + +void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->output_h_; + const int in_width = conv_param->output_w_; + + const int output_w = conv_param->input_w_; + + const int tot_channels = conv_param->output_channel_; + const int channels = tot_channels / conv_param->group_; + int channel, kernel_row, kernel_col, output_rows, output_col; + for (channel = 0; channel < channels; channel++) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + for (output_rows = start; output_rows < start + rows; output_rows++) { + int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h; + if (!((unsigned)(input_row) < (unsigned)(in_height))) { + for (output_col = output_w; output_col; output_col--) { + *(data_row++) = 0; + } + } else { + int input_col = -pad_left + kernel_col * dilation_w; + for (output_col = output_w; output_col; output_col--) { + if (((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels + channel; + *(data_row++) = in_data[offset]; + } else { + *(data_row++) = 0; + } + input_col += stride_w; + } + } + } + } + } + } +} + +void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_h = conv_param->output_h_; + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col, output_rows, output_col; + + int row_stride_offset = 0; + + for (output_rows = output_h; output_rows; output_rows--) { + int col_stride_offset = 0; + for (output_col = output_w; output_col; output_col--) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = 
-pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = data_im + offset; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + col_stride_offset += stride_w; + } + row_stride_offset += stride_h; + } +} + +void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = data_im + offset; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h new file mode 100644 index 0000000000000000000000000000000000000000..a29ae2047436803d4cd037ecbadff36546ce3250 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_PACK_EXT_H_
+#define NNACL_FP32_GRAD_PACK_EXT_H_
+
+#include <stdlib.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
+                               int real_cal_num, int block_index);
+void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
+                                 int real_cal_num, int block_index);
+
+void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
+void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
+void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_PACK_EXT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..2ba7604a65780740f58093ca2f3fa5016673f934
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c
@@ -0,0 +1,190 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
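(Editor's sketch, not part of the patch: the tile loop the rolling pack helpers above are built for — pack a strip of output positions, run a GEMM over it, move on. The GEMM itself is elided, and the wrapper name and tile size are illustrative.)

#include "nnacl_c/fp32_grad/pack_ext.h"

void ConvGradTileLoop(const float *input, const ConvParameter *conv, float *packed, int output_hw, int tile) {
  for (int start = 0; start < output_hw; start += tile) {
    int count = (output_hw - start < tile) ? (output_hw - start) : tile;
    /* Pack `count` output positions' worth of patches into `packed`. */
    RollingIm2ColPackUnitFp32(input, conv, packed, count, start);
    /* ... a GEMM over `packed` would produce output positions [start, start + count) ... */
  }
}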
+ */
+#include "nnacl_c/fp32_grad/pooling_grad.h"
+#include <math.h>
+#include <float.h>
+#include <string.h>
+#include "nnacl_c/op_base.h"
+
+void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, const PoolingParameter *pooling_param,
+                    const PoolingComputeParam *pooling_args) {
+  int stride_w = pooling_param->stride_w_;
+  int stride_h = pooling_param->stride_h_;
+  int pad_w = pooling_param->pad_l_;
+  int pad_h = pooling_param->pad_u_;
+  int win_w = pooling_args->window_w_;
+  int win_h = pooling_args->window_h_;
+  int channel = pooling_args->input_channel_;
+  int in_w = pooling_args->input_w_;
+  int in_h = pooling_args->input_h_;
+  int output_w = pooling_args->output_w_;
+  int output_h = pooling_args->output_h_;
+
+  const float kk = 1.0f / (float)(win_h * win_w);
+#if ENABLE_ARM
+  const float32x4_t factor = vdupq_n_f32(kk);
+#endif
+  for (int ib = 0; ib < count; ib++) {
+    float *out = output_ptr + ib * in_h * in_w * channel;
+    const float *inPtr = input_ptr + ib * output_h * output_w * channel;
+    // iterate over the output positions
+    for (int yh = 0; yh < output_h; yh++) {
+      int over_h = pad_h - yh * stride_h;
+      int kh_s = MSMAX(0, over_h);
+      int kh_e = MSMIN(win_h, in_h + over_h);
+      for (int yw = 0; yw < output_w; yw++) {
+        int over_w = pad_w - yw * stride_w;
+        int kw_s = MSMAX(0, over_w);
+        int kw_e = MSMIN(win_w, in_w + over_w);
+        int ic = 0;
+        for (; ic < channel - C4NUM; ic += C4NUM) {
+          int idx = (yw + yh * output_w) * channel + ic;
+#ifdef ENABLE_ARM
+          float32x4_t in = vld1q_f32(inPtr + idx);
+          float32x4_t delta = vmulq_f32(in, factor);
+#else
+          float delta[C4NUM] = {inPtr[idx], inPtr[idx + C1NUM], inPtr[idx + C2NUM], inPtr[idx + C3NUM]};
+          for (int i = 0; i < C4NUM; i++) delta[i] *= kk;
+#endif
+          for (int kh = kh_s; kh < kh_e; kh++) {
+            int xh = yh * stride_h + kh - pad_h;
+            for (int kw = kw_s; kw < kw_e; kw++) {
+              int xw = yw * stride_w + kw - pad_w;
+#ifdef ENABLE_ARM
+              float *out_vec = out + (xw + in_w * xh) * channel + ic;
+              float32x4_t outr = vld1q_f32(out_vec);
+              float32x4_t outs = vaddq_f32(outr, delta);
+              vst1q_f32(out_vec, outs);
+#else
+              for (int i = 0; i < C4NUM; i++) {
+                out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i];
+              }
+#endif
+            }
+          }
+        }
+        for (; ic < channel; ic++) {
+          int idx = (yw + yh * output_w) * channel + ic;
+          float delta = inPtr[idx] * kk;
+          for (int kh = kh_s; kh < kh_e; kh++) {
+            int xh = yh * stride_h + kh - pad_h;
+            for (int kw = kw_s; kw < kw_e; kw++) {
+              int xw = yw * stride_w + kw - pad_w;
+              out[(xw + in_w * xh) * channel + ic] += delta;
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+#ifdef ENABLE_ARM
+static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
+  uint32x4_t res = vcgtq_f32(in, *max);
+  int32x4_t m_index = vbslq_s32(res, index, prev_index);
+  *max = vbslq_f32(res, in, *max);
+  return m_index;
+}
+#endif
+
+void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
+                    const PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args) {
+  int stride_w = pooling_param->stride_w_;
+  int stride_h = pooling_param->stride_h_;
+  int pad_w = pooling_param->pad_l_;
+  int pad_h = pooling_param->pad_u_;
+  int win_w = pooling_args->window_w_;
+  int win_h = pooling_args->window_h_;
+  int channel = pooling_args->input_channel_;
+  int in_w = pooling_args->input_w_;
+  int in_h = pooling_args->input_h_;
+  int output_w = pooling_args->output_w_;
+  int output_h = pooling_args->output_h_;
+  for (int ib = 0; ib < output_batch; ib++) {
+    float *out =
output_ptr + ib * in_h * in_w * channel; + const float *inPtr = input_ptr + ib * in_h * in_w * channel; + const float *dyPtr = dy_ptr + ib * output_h * output_w * channel; + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic <= channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_ARM + uint32x4_t max_idx = vdupq_n_u32(0); + float32x4_t max_val = vdupq_n_f32(-FLT_MAX); + float32x4_t delta = vld1q_f32(dyPtr + idx); +#else + float delta[C4NUM] = {dyPtr[idx], dyPtr[idx + C1NUM], dyPtr[idx + C2NUM], dyPtr[idx + C3NUM]}; + float max_val[C4NUM] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX}; + int max_idx[C4NUM] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_ARM + uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + float32x4_t in = vld1q_f32(inPtr + val_idx); + max_idx = vreinterpretq_u32_s32( + MaxIndex(in, &max_val, vreinterpretq_s32_u32(index), vreinterpretq_s32_u32(max_idx))); +#else + float val[C4NUM] = {inPtr[val_idx], inPtr[val_idx + C1NUM], inPtr[val_idx + C2NUM], + inPtr[val_idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < C4NUM; i++) { + out[((int *)&max_idx)[i]] += ((float *)&delta)[i]; + } + } + for (; ic < channel; ic++) { + float max_val = -FLT_MAX; + int max_idx = 0; + int idx = (yw + yh * output_w) * channel + ic; + float delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; + } + } + } + out[max_idx] += delta; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..7f1e12ebf447453ac914ef6ce0ccf5717083430b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
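(Editor's note, not part of the patch: both pooling backward kernels above accumulate into the output with +=, so callers must zero-initialize dx first. A minimal wrapper, with a hypothetical name, makes the contract explicit.)

#include <string.h>
#include "nnacl_c/fp32_grad/pooling_grad.h"

void MaxPoolingGradZeroed(const float *x, const float *dy, float *dx, int batch, const PoolingParameter *param,
                          const PoolingComputeParam *args) {
  size_t dx_elems = (size_t)batch * (size_t)args->input_h_ * (size_t)args->input_w_ * (size_t)args->input_channel_;
  memset(dx, 0, dx_elems * sizeof(float));  /* required: the kernel only adds into dx */
  MaxPoolingGrad(x, dy, dx, batch, param, args);
}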
+ */
+
+#ifndef NNACL_FP32_GRAD_POOLING_GRAD_H_
+#define NNACL_FP32_GRAD_POOLING_GRAD_H_
+
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/kernel/pooling.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, const PoolingParameter *pooling_param,
+                    const PoolingComputeParam *pooling_args);
+void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
+                    const PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_POOLING_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..11d670a38d26f033421b4dc4f186cb796a013a3e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string.h>
+#include "nnacl_c/fp32_grad/reduce_grad.h"
+#include "nnacl_c/fp32_grad/utils.h"
+#include "nnacl_c/op_base.h"
+
+void ReduceMeanByAxes(const float *input_data, int *input_iter, const int *input_dims, int input_num_dims,
+                      const int *axes, int num_axes, float *output_data, const int *output_dims, int output_num_dims) {
+  size_t num_outputs = 1;
+  for (int idx = 0; idx < output_num_dims; ++idx) {
+    size_t current = (size_t)(output_dims[idx]);
+    num_outputs *= current;
+  }
+
+  // Reset input iterator.
+  for (int idx = 0; idx < input_num_dims; ++idx) {
+    input_iter[idx] = 0;
+  }
+  // Iterate through input_data.
+  do {
+    size_t input_offset = GetInputOffset(input_num_dims, input_dims, input_iter);
+    size_t output_offset = GetOutputOffset(input_num_dims, input_dims, input_iter, num_axes, axes);
+    output_data[output_offset] += input_data[input_offset];
+  } while (NextIndex(input_num_dims, input_dims, input_iter));
+
+  // Calculate the mean by dividing output_data by the number of aggregated elements.
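+  /* Worked example: reducing a [2, 3, 4] input over axes = {1} accumulates 3 elements into each of
+   * the 8 output slots, so num_elements_in_axis below evaluates to 3 and every sum is divided by 3. */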
+  size_t num_elements_in_axis = 1;
+  for (int idx = 0; idx < num_axes; ++idx) {
+    size_t current = (size_t)(input_dims[axes[idx]]);
+    num_elements_in_axis *= current;
+  }
+
+  for (size_t idx = 0; idx < num_outputs; ++idx) {
+    output_data[idx] = output_data[idx] / (float)(num_elements_in_axis);
+  }
+}
+
+float ReduceMeanAll(const float *src, int size) {
+  float sum = 0;
+  for (int i = 0; i < size; ++i) {
+    sum += src[i];
+  }
+  return sum / size;
+}
+
+void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) {
+  int num_outputs = 1;
+  int same_shape = 1;
+  for (int idx = 0; idx < num_dims; ++idx) {
+    num_outputs *= output_dims[idx];
+    if (output_dims[idx] != input_dims[idx]) same_shape = 0;
+  }
+  if (same_shape) {
+    memcpy(output, input, (size_t)(num_outputs) * sizeof(float));
+    return;
+  }
+
+  memset(output, 0, (size_t)(num_outputs) * sizeof(float));  // zero output
+
+  int input_iter[C8NUM] = {0};
+  int axes[C5NUM] = {0};
+  int num_axes = 0;
+  for (int i = 0; i < num_dims; i++) {
+    if (output_dims[i] == C1NUM && num_axes < C5NUM) {
+      axes[num_axes++] = i;
+    }
+  }
+
+  // Iterate through input_data.
+  do {
+    size_t input_offset = GetInputOffset(num_dims, input_dims, input_iter);
+    size_t output_offset = GetOutputOffset(num_dims, input_dims, input_iter, num_axes, axes);
+    output[output_offset] += input[input_offset];
+  } while (NextIndex(num_dims, input_dims, input_iter));
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..edb610254cc86fc7d60a4ba58ae2cdf827224e15
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_REDUCE_GRAD_H_
+#define NNACL_FP32_GRAD_REDUCE_GRAD_H_
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+float ReduceMeanAll(const float *src, int size);
+void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims);
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_FP32_GRAD_REDUCE_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..678abd47fc2fdfbcefae2a73c25e9ff283f7aa9d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c
@@ -0,0 +1,149 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32_grad/resize_grad.h"
+#include <math.h>
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/errorcode.h"
+
+int ResizeNearestNeighborGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format,
+                              const ResizeGradParameter *param) {
+  bool align_corners = param->align_corners_;
+  size_t in_hw_size = param->in_width_ * param->in_height_;
+  size_t out_hw_size = param->out_width_ * param->out_height_;
+
+  if (format == Format_NHWC) {
+    NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_);
+    for (int32_t b = 0; b < batch_size; ++b) {
+      for (size_t i = 0; i < in_hw_size; ++i) {
+        size_t in_y = i / param->in_width_;
+        size_t in_x = i % param->in_width_;
+        for (size_t c = 0; c < (size_t)channel; ++c) {
+          size_t out_y = MSMIN(
+            (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_),
+            param->out_height_ - 1);
+          size_t out_x = MSMIN(
+            (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_),
+            param->out_width_ - 1);
+          size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c;
+          size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c;
+          out_addr[out_offset] += in_addr[in_offset];
+        }
+      }
+      out_addr += out_hw_size * channel;
+      in_addr += in_hw_size * channel;
+    }
+  } else if (format == Format_NCHW) {
+    for (int32_t b = 0; b < batch_size; ++b) {
+      for (size_t c = 0; c < (size_t)channel; ++c) {
+        for (size_t h = 0; h < param->in_height_; ++h) {
+          size_t out_y =
+            MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_),
+                  param->out_height_ - 1);
+          for (size_t w = 0; w < param->in_width_; ++w) {
+            size_t out_x =
+              MSMIN((align_corners) ? 
(size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), + param->out_width_ - 1); + out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} + +int ResizeBiLinearGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (size_t c = 0; c < (size_t)channel; ++c) { + float in_y = (float)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + + float in_x = (float)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + size_t in_height = param->in_height_; + size_t in_width = param->in_width_; + size_t out_height = param->out_height_; + size_t out_width = param->out_width_; + out_hw_size = out_height * out_width; + in_hw_size = in_height * in_width; + + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < (size_t)channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float in_y = (float)(h)*param->height_scale_; + const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); + const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); + const float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float in_x = (float)(w)*param->width_scale_; + const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); + const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); + const float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + out_addr[top_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float)(inverse_y_lerp * inverse_x_lerp); + 
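/* The three remaining corner updates below use the complementary weights; for any (x_lerp, y_lerp) the four bilinear weights sum to 1, so the backward pass redistributes exactly the incoming gradient. */ +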
out_addr[top_y_index * out_width + right_x_index] +=
+              in_addr[h * in_width + w] * (float)(inverse_y_lerp * x_lerp);
+            out_addr[bottom_y_index * out_width + left_x_index] +=
+              in_addr[h * in_width + w] * (float)(y_lerp * inverse_x_lerp);
+            out_addr[bottom_y_index * out_width + right_x_index] +=
+              in_addr[h * in_width + w] * (float)(y_lerp * x_lerp);
+          }
+        }
+        out_addr += out_hw_size;
+        in_addr += in_hw_size;
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..b0f65a605abb91e8c2417cec10d78d7e926079ab
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_RESIZE_GRAD_H_
+#define NNACL_FP32_GRAD_RESIZE_GRAD_H_
+
+#include "nnacl_c/fp32_grad/resize_grad_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ResizeNearestNeighborGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format,
+                              const ResizeGradParameter *param);
+int ResizeBiLinearGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format,
+                       const ResizeGradParameter *param);
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_FP32_GRAD_RESIZE_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..2d0b46296519b2389a896dfa57edaa5e3439b53f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_
+#define NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct ResizeGradParameter {
+  OpParameter op_parameter_;
+  bool align_corners_;
+  int method;
+  size_t in_height_;
+  size_t in_width_;
+  size_t out_height_;
+  size_t out_width_;
+  float height_scale_;
+  float width_scale_;
+} ResizeGradParameter;
+
+#endif  // NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h
new file mode 100644
index 0000000000000000000000000000000000000000..7c46b4b5b4ddb432d3217505f04201c17566c02b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_
+#define NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct {
+  OpParameter op_parameter_;
+  float beta_;
+} SmoothL1LossParameter;
+
+#endif  // NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c
new file mode 100644
index 0000000000000000000000000000000000000000..594ef6de071e8871df061f3b80975d9f31b3b794
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h"
+#include <math.h>
+
+void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2,
+                        size_t number_of_classes, int batch_size) {
+  float eps = 1e-6f;
+  if (grads != NULL) {
+    for (size_t i = 0; i < (size_t)(batch_size); ++i) {
+      float loss = 0.f;
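+      /* Row i yields loss_i = -sum_j labels[j] * log(max(logits[j], eps)); when grads is supplied it
+       * also emits grads = logits - labels, the softmax cross-entropy gradient, which presumes the
+       * logits buffer already holds softmax probabilities. */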
+      for (size_t j = 0; j < number_of_classes; ++j) {
+        float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]);
+        grads[i * number_of_classes + j] = (logits[i * number_of_classes + j] - labels[i * number_of_classes + j]);
+        loss += labels[i * number_of_classes + j] * logit;
+      }
+      output2[i] = loss;
+    }
+  } else {
+    for (size_t i = 0; i < (size_t)(batch_size); ++i) {
+      float loss = 0.f;
+      for (size_t j = 0; j < number_of_classes; ++j) {
+        float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]);
+        loss += labels[i * number_of_classes + j] * logit;
+      }
+      output2[i] = loss;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h
new file mode 100644
index 0000000000000000000000000000000000000000..7cd53bc4a9d4214f3f39c716635333a3b1f35019
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
+#define NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2,
+                        size_t number_of_classes, int batch_size);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..57a030698d8bbdeceb282a213f30f2f169ade200
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_
+#define NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct SoftmaxCrossEntropyParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int n_dim_;
+
+  // shape correlative
+  int input_shape_[5];
+
+  // other parameter
+  int32_t batch_size_;
+  unsigned int number_of_classes_;
+  bool is_grad_;
+} SoftmaxCrossEntropyParameter;
+
+#endif  // NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..bdc6e75081bb621b9796c436c75a4da1aadeb9ac
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/softmax_grad.h"
+#include <string.h>
+
+void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,
+                 const int *input_shape, int n_dim, int ele_size, int32_t axis) {
+  int dim = 1;
+  int inner_size = 1, outter_size = 1;
+  for (int i = 0; i < axis; i++) {
+    outter_size *= input_shape[i];
+  }
+  for (int i = axis + 1; i < n_dim; i++) {
+    inner_size *= input_shape[i];
+  }
+  NNACL_CHECK_ZERO_RETURN(outter_size);
+  for (int i = 0; i < inner_size * input_shape[axis]; i++) sum_mul[i] = 1.0f;
+  for (int i = 0; i < n_dim; i++) dim *= input_shape[i];
+  dim /= outter_size;
+  memcpy(output_ptr, yt_ptr, (size_t)(ele_size) * sizeof(float));
+
+  const int M = input_shape[axis];
+  const int N = inner_size;
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * dim;
+    memset(sum_data, 0, (size_t)(inner_size) * sizeof(float));
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int offset = inner_offset + j * inner_size;
+        sum_data[k] += output_ptr[offset] * input_ptr[offset];
+      }
+    }
+    for (int k = 0; k < M; ++k) {
+      float a = -sum_mul[k];
+      for (int j = 0; j < N; ++j) {
+        *(output_ptr + outter_offset + k * N + j) += a * sum_data[j];
+      }
+    }
+  }
+
+  for (int i = 0; i < ele_size; i++) {
+    output_ptr[i] *= input_ptr[i];
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..69f9b3312f4e442852c2345bc26c746a30719638
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
+#define NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
+
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32_grad/softmax_crossentropy_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,
+                 const int *input_shape, int n_dim, int ele_size, int32_t axis);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c
new file mode 100644
index 0000000000000000000000000000000000000000..24c62ea373c6d56dbf9d3c9dcd5ad4c71b2ff42b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c
@@ -0,0 +1,102 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/softmax_grad_utils.h"
+#include <math.h>
+#include <float.h>
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+void ExpFp32Offset(const float *src, float *dst, float sub_bias, int num) {
+  int i = 0;
+#ifdef ENABLE_ARM64
+  int count = (num / C4NUM) * C4NUM;
+  for (; i < count; i += C4NUM) {
+    MS_FLOAT32X4 input = vld1q_f32(src + i);
+    MS_FLOAT32X4 bias = vdupq_n_f32(sub_bias);
+    MS_FLOAT32X4 i1 = vsubq_f32(input, bias);
+    simd_exp128(i1, dst + i);
+  }
+#endif
+  for (; i < num; ++i) {
+    simd_exp32(src[i] - sub_bias, dst + i);
+  }
+}
+
+// output = exp(input) / reduce_sum(exp(input), axis)
+static void SoftMaxP1Simple(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count,
+                            int length) {
+  for (int i = start; i < start + count; i++) {
+    int inner_offset = i * length;
+    float max_data = input_ptr[inner_offset];
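+    /* Track the row maximum first: exponentiating (input - max) keeps the exponential in range
+     * without changing the softmax result, since the common factor cancels in the later division. */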
+    for (int j = 0; j < length; j++) {
+      int axis_offset = inner_offset + j;
+      max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset];
+    }
+    ExpFp32Offset(input_ptr + inner_offset, output_ptr + inner_offset, max_data, length);
+    float _sum_data = 0;
+    for (int j = 0; j < length; j++) {
+      int axis_offset = inner_offset + j;
+      _sum_data += output_ptr[axis_offset];
+    }
+    sum_data[i] = _sum_data;
+  }
+}
+
+void SoftMaxP1(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, int length,
+               int inner_size) {
+  if (inner_size == 1) {
+    SoftMaxP1Simple(input_ptr, output_ptr, sum_data, start, count, length);
+    return;
+  }
+  for (int i = start; i < start + count; i++) {
+    int outter_offset = i * length * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      float max_data = input_ptr[inner_offset];
+      for (int j = 0; j < length; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset];
+      }
+      for (int j = 0; j < length; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data);
+      }
+      float _sum_data = 0;
+      for (int j = 0; j < length; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        _sum_data += output_ptr[axis_offset];
+      }
+      sum_data[k + sum_outter_offset] = _sum_data;
+    }
+  }
+}
+
+void SoftMaxP2(const float *input_ptr, float *output_ptr, const float *sum_data, int start, int count, int length,
+               int inner_size) {
+  for (int i = start; i < start + count; i++) {
+    int outter_offset = i * length * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int j = 0; j < length; j++) {
+      int axis_offset = outter_offset + j * inner_size;
+      for (int k = 0; k < inner_size; k++) {
+        int inner_offset = axis_offset + k;
+        output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset];
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..68aac00d02245d7aeba0fd48e4a2653201a3925b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_H_
+#define NNACL_FP32_GRAD_SOFTMAX_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void SoftMaxP1(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, int length,
+               int inner_size);
+void SoftMaxP2(const float *input_ptr, float *output_ptr, const float *sum_data, int start, int count, int length,
+               int inner_size);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_SOFTMAX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c
new file mode 100644
index 0000000000000000000000000000000000000000..4a9fe7659c5d62819e03c965780dff2833d2f891
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/strided_slice_grad.h"
+#include "nnacl_c/errorcode.h"
+
+static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) {
+  size_t res = 1;
+  for (size_t j = 0; j < size; j++) {
+    res *= shape[((size_t)(i) + 1) + j];
+  }
+  NNACL_CHECK_ZERO_RETURN_ERR(res);
+  NNACL_CHECK_ZERO_RETURN_ERR(shape[i]);
+  return (pos / res % shape[i]);
+}
+
+int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, const StridedSliceParameter *param) {
+  if (inputs == NULL || output == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  if (param->num_axes_ > DIMENSION_8D) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  size_t size = 1;
+  const int *s = param->strides_;
+  const int *b = param->begins_;
+  for (int i = 0; i < DIMENSION_8D; i++) {
+    size *= (size_t)(param->in_shape_[i]);
+  }
+
+  for (size_t pos = 0; pos < size; pos++) {
+    size_t i = CalcIndex(param->in_shape_, C7NUM, C0NUM, pos);
+    size_t j = CalcIndex(param->in_shape_, C6NUM, C1NUM, pos);
+    size_t k = CalcIndex(param->in_shape_, C5NUM, C2NUM, pos);
+    size_t l = CalcIndex(param->in_shape_, C4NUM, C3NUM, pos);
+    size_t m = CalcIndex(param->in_shape_, C3NUM, C4NUM, pos);
+    size_t n = CalcIndex(param->in_shape_, C2NUM, C5NUM, pos);
+    size_t o = CalcIndex(param->in_shape_, C1NUM, C6NUM, pos);
+    size_t p = CalcIndex(param->in_shape_, C0NUM, C7NUM, pos);
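+    /* pos enumerates the incoming (sliced) gradient in row-major order and (i..p) are its 8-D
+     * coordinates; each axis maps back into dx as begin + coordinate * stride. For a last axis with
+     * s[C7NUM] = 2 and b[C7NUM] = 1, coordinate p = 3 lands in dx column 1 + 3 * 2 = 7. */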
+    size_t input_idx =
+      (i * s[C0NUM] + b[C0NUM]) * dx_shape[C1NUM] * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] *
+        dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] +
+      (j * s[C1NUM] + b[C1NUM]) * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] *
+        dx_shape[C6NUM] * dx_shape[C7NUM] +
+      (k * s[C2NUM] + b[C2NUM]) * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] *
+        dx_shape[C7NUM] +
+      (l * s[C3NUM] + b[C3NUM]) * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] +
+      (m * s[C4NUM] + b[C4NUM]) * dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] +
+      (n * s[C5NUM] + b[C5NUM]) * dx_shape[C6NUM] * dx_shape[C7NUM] + (o * s[C6NUM] + b[C6NUM]) * dx_shape[C7NUM] +
+      (p * s[C7NUM] + b[C7NUM]);
+    output[input_idx] = inputs[pos];
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..e779e6de22bfdcb2b42c2a8bb86382fc4828bfea
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
+#define NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/strided_slice_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, const StridedSliceParameter *param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..4798715f5959f9952c7d14fe9855b5a49191c441
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP32_GRAD_UTILS_H_
+#define NNACL_FP32_GRAD_UTILS_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+static inline size_t GetInputOffset(int num_dims, const int *dims, const int *iter) {
+  size_t offset = 0;
+  for (int idx = 0; idx < num_dims; ++idx) {
+    offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]);
+  }
+
+  return offset;
+}
+
+// Row-major offset into the reduced tensor: dimensions listed in axes are skipped (treated as size 1).
+static inline size_t GetOutputOffset(int num_dims, const int *dims, const int *iter, int num_axis, const int *axes) {
+  size_t offset = 0;
+  for (int idx = 0; idx < num_dims; ++idx) {
+    // if we need to skip this axis
+    int is_axis = 0;
+    for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
+      if (idx == axes[axis_idx]) {
+        is_axis = 1;
+        break;
+      }
+    }
+
+    if (is_axis == 0) {
+      offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]);
+    }
+  }
+  return offset;
+}
+
+// Odometer-style advance: increment the innermost coordinate, propagate the carry outward, and
+// return 0 once every coordinate has wrapped (iteration finished).
+static inline int NextIndex(int num_dims, const int *dims, int *current) {
+  int carry = 1;
+  for (int idx = num_dims - 1; idx >= 0; --idx) {
+    int current_val = current[idx] + carry;
+    if (dims[idx] == current_val) {
+      current[idx] = 0;
+    } else {
+      current[idx] = current_val;
+      carry = 0;
+      break;
+    }
+  }
+  return (carry == 0);
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_UTILS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c
new file mode 100644
index 0000000000000000000000000000000000000000..53c2fe9223baaf34889b13a5e5b4ba8ed8e32b09
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h"
+#ifdef ENABLE_ARM64
+#include <arm_neon.h>
+#endif
+
+// b packs only the non-zero weights; nnz[oc] counts them per output channel and dmap gives the
+// matching byte offsets into the dense input a.
+void MatMulSparse8x8(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c,
+                     const float *bias, ActType act_type, int out_stride) {
+#ifndef ENABLE_ARM64
+  return;
+#else
+  // mul-acc
+  for (int oc = 0; oc < 8; oc++) {
+    uint32_t cur_nnz = nnz[oc];
+    // init 8x1 C with bias
+    float32x4_t vacc1 = vld1q_dup_f32(bias + oc);
+    float32x4_t vacc2 = vacc1;
+    for (uint32_t nz = 0; nz < cur_nnz; nz++) {
+      // load w
+      float32x4_t vw = vld1q_dup_f32(b++);
+      // load 8 inputs
+      const float *input = a + (*(dmap++) / sizeof(float));
+      float32x4_t vi1 = vld1q_f32(input);
+      float32x4_t vi2 = vld1q_f32(input + 4);
+      vacc1 = vfmaq_f32(vacc1, vi1, vw);
+      vacc2 = vfmaq_f32(vacc2, vi2, vw);
+    }
+    // save output
+    *(c + oc) = vacc1[0];
+    *(c + 1 * out_stride + oc) = vacc1[1];
+    *(c + 2 * out_stride + oc) = vacc1[2];
+    *(c + 3 * out_stride + oc) = vacc1[3];
+    *(c + 4 * out_stride + oc) = vacc2[0];
+    *(c + 5 * out_stride + oc) = vacc2[1];
+    *(c + 6 * out_stride + oc) = vacc2[2];
+    *(c + 7 * out_stride + oc) = vacc2[3];
+  }
+#endif
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h
new file mode 100644
index 0000000000000000000000000000000000000000..d95d235c0c89e732d1b82578ee7e7d6e63ea1c98
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_MATMUL_SPARSE_X1_H_
+#define NNACL_FP32_MATMUL_SPARSE_X1_H_
+
+#include <stdint.h>
+#include <stddef.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef ENABLE_ARM64
+void SPMM8x8Fp32(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, const float *bias,
+                 ActType act_type, size_t out_stride);
+#endif
+
+void MatMulSparse8x8(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c,
+                     const float *bias, ActType act_type, int out_stride);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_FP32_MATMUL_SPARSE_X1_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..b52fa90d0430dc0a6acf7a9bad00b66e4f5f70a8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_ND_PARAMETER_H_ +#define NNACL_GATHER_ND_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; +} GatherNdParameter; + +#endif // NNACL_GATHER_ND_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..8b9d729c4c25ea9558015c35b6f59e010d9d97f0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_PARAMETER_H_ +#define NNACL_GATHER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GatherParameter { + // Primitive parameter + OpParameter op_parameter_; + int axis_; +} GatherParameter; + +#endif // NNACL_GATHER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..fa1e4ca87346346823d388c0a404924c089e3eae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_GELU_PARAMETER_H_ +#define NNACL_GELU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GeLUParameter { + // Primitive parameter + OpParameter op_parameter_; + bool approximate_; +} GeLUParameter; + +#endif // NNACL_GELU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b01912a9804ca91cff8f1def654b49ff9cda03a6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GLU_PARAMETER_H_ +#define NNACL_GLU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GluParameter { + OpParameter op_parameter_; + int axis_; +} GluParameter; + +#endif // NNACL_GLU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..422162a1be6f75c4cfe1413849ccb04eaa781999 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GRID_SAMPLER_PARAMETER_H_ +#define NNACL_GRID_SAMPLER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GridSamplerParameter { + OpParameter op_parameter_; + int64_t interpolation_mode_; + int64_t padding_mode_; + bool align_corners_; +} GridSamplerParameter; + +#endif // NNACL_GRID_SAMPLER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..a733d1b22b6b089c36ad8b02422ce89a263b5b89 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GROUP_NORM_PARAMETER_H_ +#define NNACL_GROUP_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" +typedef struct GroupNormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + int num_groups_; + int channel_; + int unit_; + int batch_; + bool affine_; + void *mean_; + void *variance_; +} GroupNormParameter; + +typedef struct GroupNormQuantArg { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} GroupNormQuantArg; + +#endif // NNACL_GROUP_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..36c84e2fb4feda379e5d51b1191f10dbd0069719 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GRU_PARAMETER_H_ +#define NNACL_GRU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; +} GruParameter; + +#endif // NNACL_GRU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4ee74d682a0c5185e16f1701d2864edbbefa257a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/activation_grad_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int ActivationGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+  const TensorC *input = inputs[0];
+  const TensorC *input_grad = inputs[1];
+  if (input->shape_size_ != input_grad->shape_size_) {
+    return NNACL_ERR;
+  }
+  for (size_t i = 0; i < input->shape_size_; i++) {
+    if (input->shape_[i] != input_grad->shape_[i]) {
+      return NNACL_ERR;
+    }
+  }
+
+  SetDataTypeFormat(outputs[0], inputs[0]);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  SetShapeTensor(outputs[0], inputs[0]);
+  return NNACL_OK;
+}
+
+REG_INFER(ActivationGrad, PrimType_ActivationGrad, ActivationGradInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..cb128bd57963c790b65630e0101c57da313d8c54
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H
+#define MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ActivationGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..ebc7ba14272967a2b464cfeab12e39ba4599c2eb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/infer/adam_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int AdamInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter) {
+  int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 10);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) ||
+      NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[2]) ||
+      NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[9]) || NNACLGetElementNum(inputs[3]) != 1 ||
+      NNACLGetElementNum(inputs[4]) != 1 || NNACLGetElementNum(inputs[5]) != 1 || NNACLGetElementNum(inputs[6]) != 1 ||
+      NNACLGetElementNum(inputs[7]) != 1 || NNACLGetElementNum(inputs[8]) != 1) {
+    return NNACL_ERR;
+  }
+  if (outputs_size != 0) {
+    TensorC *out = outputs[0];
+    SetDataTypeFormat(out, inputs[0]);
+    out->shape_size_ = 1;
+    out->shape_[0] = 1;
+  }
+
+  return NNACL_OK;
+}
+
+REG_INFER(Adam, PrimType_Adam, AdamInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..b251f1ca9e95f997c0ebe1e71a2ba8c89da115a5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_ADAM_INFER_H
+#define MINDSPORE_NNACL_ADAM_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int AdamInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ADAM_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..fbb8a918af819ff9578bc07a70d5b97f8bb969eb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/infer/adam_weight_decay_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const size_t expected_inputs_size = 10; + const int var_idx = 0; + const int m_idx = 1; + const int v_idx = 2; + const int lr_idx = 3; + const int beta1_idx = 4; + const int beta2_idx = 5; + const int epsilon = 6; + const int decay_idx = 7; + const int grad_idx = 8; + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, expected_inputs_size); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[m_idx]) || + NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[v_idx]) || + NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[grad_idx]) || + NNACLGetElementNum(inputs[lr_idx]) != 1 || NNACLGetElementNum(inputs[beta1_idx]) != 1 || + NNACLGetElementNum(inputs[beta2_idx]) != 1 || NNACLGetElementNum(inputs[epsilon]) != 1 || + NNACLGetElementNum(inputs[decay_idx]) != 1) { + return NNACL_ERR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(AdamWeightDecay, PrimType_AdamWeightDecay, AdamWeightDecayInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..13d5c63624d9978ee56c4e50e620b50c6883ae18 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H +#define MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b84416ea33cffac5bac47e70e8e7f758585bc5af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/add_sub_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/infer/infer_register.h" + +int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *dy = inputs[0]; + const TensorC *x1 = inputs[1]; + const TensorC *x2 = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + size_t fillDimNum0 = dy->shape_size_ - x1->shape_size_; + size_t fillDimNum1 = dy->shape_size_ - x2->shape_size_; + size_t j0 = 0; + size_t j1 = 0; + for (size_t i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); + return NNACL_OK; +} + +REG_INFER(AddGrad, PrimType_AddGrad, AddSubGradInferShape) +REG_INFER(SubGrad, PrimType_SubGrad, AddSubGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..216edefaf5dffbb708ce68aa86f6b8cd6ffbc588 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H +#define MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c2f2a8fd8cd6d34e483966df8e8df3b2f981c04f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/addn_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + if (inputs_size < 2) { + return NNACL_ERR; + } + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + size_t max_dims = input->shape_size_; + size_t max_dims_idx = 0; + + // check zero dimension + for (size_t i = 0; i < max_dims; i++) { + NNACL_CHECK_FALSE(input->shape_[i] == 0, NNACL_ERR); + } + + // determine max_dims + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ > max_dims) { + max_dims = inputs[i]->shape_size_; + max_dims_idx = i; + } + } + ShapeSet(output->shape_, &output->shape_size_, inputs[max_dims_idx]->shape_, inputs[max_dims_idx]->shape_size_); + + // make sure all elements have the same size or 1 (broadcasting) in all dimensions + for (size_t i = 1; i < inputs_size; ++i) { + if ((inputs[i]->shape_size_ != max_dims) && + (NNACLGetElementNum(inputs[i]) != NNACLGetElementNum(inputs[max_dims_idx]))) { + return NNACL_ERR; + } + if (inputs[i]->shape_size_ == max_dims) { + for (size_t j = 0; j < max_dims; j++) { + if (inputs[i]->shape_[j] != inputs[max_dims_idx]->shape_[j] && inputs[i]->shape_[j] != 1 && + inputs[max_dims_idx]->shape_[j] != 1) { + return NNACL_ERR; + } + } + } + } + + for (size_t d = 0; d < inputs[max_dims_idx]->shape_size_; ++d) { + size_t max_dim = 0; + for (size_t i = 0; i < inputs_size; ++i) { + size_t shift = max_dims - (size_t)(inputs[i]->shape_size_); + size_t dim = (d < shift) ?
1 : (size_t)(inputs[i]->shape_[d - shift]); + if (dim > max_dim) { + max_dim = dim; + } + } + output->shape_[d] = (int)(max_dim); // set the biggest dimension in the output tensor + } + + return NNACL_OK; +} + +REG_INFER(AddN, PrimType_AddN, AddnInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3655d9c77155beb47fb819096e17b7979bc59709 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADDN_INFER_H +#define MINDSPORE_NNACL_ADDN_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADDN_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b52f762726000147aa52b9714da6af32c0c3e5bf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c @@ -0,0 +1,122 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/infer/affine_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int MatmulInfer(const AffineParameter *param, int a_shape[MAX_SHAPE_SIZE], size_t a_shape_size, + int b_shape[MAX_SHAPE_SIZE], size_t b_shape_size) { + MatMulParameter *matmul_param = param->matmul_parameter_; + NNACL_CHECK_NULL_RETURN_ERR(matmul_param); + if (matmul_param->a_transpose_) { + if (a_shape_size < 2) { + return NNACL_ERR; + } + iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]); + } + if (matmul_param->b_transpose_) { + if (b_shape_size < 2) { + return NNACL_ERR; + } + iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); + } + return NNACL_OK; +} + +int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + // splice + matmul + TensorC *input0 = (TensorC *)inputs[0]; + TensorC *input1 = (TensorC *)inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + AffineParameter *param = (AffineParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + int a_shape[MAX_SHAPE_SIZE] = {0}; + size_t a_shape_size = 0; + ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_); + if (a_shape_size == 4 && a_shape[2] == 1 && a_shape[3] == 1) { + a_shape_size = 2; + SetShapeArray(input0, a_shape, a_shape_size); + } + int context_min = param->context_[0]; + int context_max = param->context_[param->context_size_ - 1]; + + a_shape[1] = input0->shape_[1] - (context_max - context_min); + a_shape[2] = param->output_dim_; + + int b_shape[MAX_SHAPE_SIZE] = {0}; + size_t b_shape_size = 0; + ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_); + + bool del_start = false; + bool del_end = false; + if (a_shape_size == 1) { + int ret = ShapeInsert(a_shape, &a_shape_size, 0, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(input0, a_shape, a_shape_size); + del_start = true; + } + if (b_shape_size == 1) { + ShapePush(b_shape, &b_shape_size, 1); + SetShapeArray(input1, b_shape, b_shape_size); + del_end = true; + } + for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { + if (a_shape[a_shape_size - 3 - i] != b_shape[b_shape_size - 3 - i]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + + int ret = MatmulInfer(param, a_shape, a_shape_size, b_shape, b_shape_size); + if (ret != NNACL_OK) { + return ret; + } + + int c_shape[MAX_SHAPE_SIZE]; + size_t c_shape_size = 0; + ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size); + if (c_shape_size < 1 || b_shape_size < 1) { + return NNACL_ERR; + } + c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1]; + if (del_start) { + int erase_ret = ShapeErase(c_shape, &c_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } + if (del_end) { + c_shape_size--; + } + SetShapeArray(output, c_shape, c_shape_size); + return NNACL_OK; +} + +REG_INFER(Affine, PrimType_Affine, AffineInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..04987c2152441cc8a03b3a4f054eaec537206a8a --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ +#define MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/affine_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2f97643c2b2dd31ddae73ed3da55765994f87caf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/all_gather_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (inputs_size != 1 || outputs_size != 1) { + return NNACL_NULL_PTR; + } + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + AllGatherParameter *param = (AllGatherParameter *)parameter; + if (param->rank_size_ <= 0) { + return NNACL_INFER_INVALID; + } + + const TensorC *input_tensor = inputs[0]; + const int *in_shape = input_tensor->shape_; + TensorC *out_tensor = outputs[0]; + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + out_shape[0] = in_shape[0] * param->rank_size_; + out_shape_size++; + for (int i = 1; i < input_tensor->shape_size_; i++) { + out_shape[i] = in_shape[i]; + out_shape_size++; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(AllGather, PrimType_AllGather, AllGatherInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..40527aa60774409ba61e064293d61ce1f83a9a37 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_ +#define MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/all_gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..15bb66047c048a148bbd17b012ba41fcc686579b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/apply_momentum_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ApplyMomentumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 5); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[3]) || NNACLGetElementNum(inputs[2]) != 1 || + NNACLGetElementNum(inputs[4]) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + if (out == NULL) { + return NNACL_NULL_PTR; + } + out->data_type_ = inputs[0]->data_type_; + out->format_ = inputs[0]->format_; + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(ApplyMomentum, PrimType_ApplyMomentum, ApplyMomentumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..57e3ae1031337de24eff8fd0675ef4b99aacc8cb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H +#define MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ApplyMomentumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b36e7fa8638e7cb4f30fc788e2affac0962b5b59 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/argmin_max_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 1 || outputs_size > 2) { + return NNACL_ERR; + } + + ArgMinMaxParameter *param = (ArgMinMaxParameter *)parameter; + const TensorC *input = inputs[0]; + TensorC *output_1 = NULL; + TensorC *output_2 = NULL; + if (outputs_size == 2) { + output_1 = outputs[0]; + output_2 = outputs[1]; + } else if (param->out_value_) { + output_2 = outputs[0]; + } else { + output_1 = outputs[0]; + } + + if (output_1 != NULL) { + output_1->format_ = input->format_; + output_1->data_type_ = kNumberTypeInt32; + } + if (output_2 != NULL) { + SetDataTypeFormat(output_2, input); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int input_shape_size = (int)input->shape_size_; + int axis = param->axis_ < 0 ? param->axis_ + input_shape_size : param->axis_; + if (axis >= input_shape_size || axis < 0) { + return NNACL_PARAM_INVALID; + } + if (param->topk_ == 1 && !param->keep_dims_) { + int erase_ret = ShapeErase(output_shape, &output_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } else { + output_shape[axis] = param->topk_; + } + + if (output_1 != NULL) { + SetShapeArray(output_1, output_shape, output_shape_size); + } + if (output_2 != NULL) { + SetShapeArray(output_2, output_shape, output_shape_size); + } + return NNACL_OK; +} + +REG_INFER(ArgMin, PrimType_ArgMinFusion, ArgMinMaxInferShape) +REG_INFER(ArgMax, PrimType_ArgMaxFusion, ArgMinMaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1afc71fc06d515281736ce5ae88b1d6fe1e222f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ARGMAX_INFER_H +#define MINDSPORE_NNACL_ARGMAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/arg_min_max_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARGMAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7ad9d70c0d7a8fefb1fb3fc318823f460520aea6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/arithmetic_compare_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ArithmeticCompareInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int res = ArithmeticInferShape(inputs, inputs_size, outputs, outputs_size, parameter); + TensorC *output = outputs[0]; + if (output == NULL) { + return NNACL_NULL_PTR; + } + output->data_type_ = kNumberTypeBool; + return res; +} + +REG_INFER(Equal, PrimType_Equal, ArithmeticCompareInferShape) +REG_INFER(Greater, PrimType_Greater, ArithmeticCompareInferShape) +REG_INFER(GreaterEqual, PrimType_GreaterEqual, ArithmeticCompareInferShape) +REG_INFER(Less, PrimType_Less, ArithmeticCompareInferShape) +REG_INFER(LessEqual, PrimType_LessEqual, ArithmeticCompareInferShape) +REG_INFER(NotEqual, PrimType_NotEqual, ArithmeticCompareInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3fce52c6caf4d33cf2b9e8fc793bb0002686fd43 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H +#define MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H + +#include "nnacl_c/infer/arithmetic_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticCompareInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ed5cdab65b4401de9377d0d42fec786d01d6e3d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c @@ -0,0 +1,107 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/arithmetic_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +/* + * The arithmetic grad ops include AddGrad, SubGrad, MulGrad, DivGrad, MaximumGrad and MinimumGrad + * (according to arithmetic_fp32.h at present). + * MaximumGrad and MinimumGrad run through MaximumGradInferShape, + * AddGrad and SubGrad run through AddSubGradInferShape, + * and the others run through this function. + */ +int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *dy = inputs[0]; + const TensorC *x1 = inputs[1]; + const TensorC *x2 = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (dy->shape_size_ > MAX_SHAPE_SIZE || x1->shape_size_ > MAX_SHAPE_SIZE || x2->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int in_shape0[MAX_SHAPE_SIZE] = {0}; + size_t in_shape0_size = 0; + ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); + int in_shape1[MAX_SHAPE_SIZE] = {0}; + size_t in_shape1_size = 0; + ShapeSet(in_shape1, &in_shape1_size, x2->shape_, x2->shape_size_); + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_); + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + if (NNACLGetElementNum(dx1) < NNACLGetElementNum(dx2)) { + param->ndim_ = in_shape1_size; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + size_t fill_dim_num = in_shape1_size - in_shape0_size; // This will not work for batch!
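+ // Right-align x1 (the smaller operand): the leading fill_dim_num dims broadcast as 1 and the rest copy x1's dims; note the padded shape is stored in in_shape1_ while the larger operand's dims fill in_shape0_.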
+ int j = 0; + for (unsigned int i = 0; i < in_shape1_size; i++) { + if (i < fill_dim_num) { + param->in_shape1_[i] = 1; + } else { + param->in_shape1_[i] = in_shape0[j++]; + } + param->in_shape0_[i] = in_shape1[i]; + param->out_shape_[i] = out_shape[i]; + } + } else if (NNACLGetElementNum(dx2) < NNACLGetElementNum(dx1)) { + param->ndim_ = in_shape0_size; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + param->broadcasting_ = true; + int j = 0; + size_t fill_dim_num = in_shape0_size - in_shape1_size; + for (unsigned int i = 0; i < in_shape0_size; i++) { + if (i < fill_dim_num) { + param->in_shape1_[i] = 1; + } else { + param->in_shape1_[i] = in_shape1[j++]; + } + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; + } + } else { + param->broadcasting_ = false; + for (unsigned int i = 0; i < in_shape0_size; i++) { + param->in_shape1_[i] = in_shape1[i]; + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; + } + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + dx1->data_type_ = dy->data_type_; + dx2->data_type_ = dy->data_type_; + return NNACL_OK; +} + +REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape) +REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b7d6bb5498c863297872f2ce31d0a1b2ccb4e68f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..247ec130fb57e7119298bfbe0fc02da1f441a4b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/arithmetic_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/infer/broadcast_to_infer.h" + +void SetOutputDtypeFormat(const TensorC *input0, const TensorC *input1, TensorC *output) { + output->format_ = input0->format_; + output->data_type_ = input0->data_type_; + // e.g. input0's shape is 1 and input1's shape is 1 15 15 1 + // only regard larger shape size input as the right format input currently + // legacy problem: if input0 infer failed before, its shape is [-1], and input1's shape is [1,2] which need to + // be broadcasted. In this case our program will use input1's format, that's wrong and need to be solved later. + if (input0->data_ != NULL || input0->shape_size_ < input1->shape_size_) { + output->format_ = input1->format_; + } + // when input0 is const, it is quanted before insert quant trans op, so use input1 data type instead + if (((input0->data_ != NULL) && (input1->data_type_ != kTypeUnknown)) || + ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32))) { + output->data_type_ = input1->data_type_; + } +} + +int BroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape, + bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + param->broadcasting_ = false; + + const TensorC *input0 = inputs[0]; + const TensorC *input1 = inputs[1]; + TensorC *output = outputs[0]; + + const int *input_shape0 = input0->shape_; + size_t input_shape0_size = input0->shape_size_; + const int *input_shape1 = input1->shape_; + size_t input_shape1_size = input1->shape_size_; + SetOutputDtypeFormat(input0, input1, output); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + int output_shape[MAX_SHAPE_SIZE] = {0}; + int ndim = (int)input_shape0_size; + bool has_broad_cast = false; + if (BroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0, in_shape1, + output_shape, &has_broad_cast) != NNACL_OK) { + return NNACL_ERR; + } + + SetShapeArray(output, output_shape, ndim); + + param->broadcasting_ = has_broad_cast; + param->ndim_ = (size_t)ndim; + if (ndim > MAX_SHAPE_SIZE) { + 
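+ // Defensive bound only: BroadCastInferShape already returns NNACL_INFER_INVALID when ndim >= MAX_SHAPE_SIZE, so this branch should be unreachable; it guards the fixed-size memcpy destinations below.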
return NNACL_ERR; + } + memcpy(param->in_shape0_, in_shape0, ndim * sizeof(int)); + memcpy(param->in_shape1_, in_shape1, ndim * sizeof(int)); + memcpy(param->out_shape_, output_shape, ndim * sizeof(int)); + + param->in_elements_num0_ = 1; + param->in_elements_num1_ = 1; + param->out_elements_num_ = 1; + for (int i = 0; i < ndim; i++) { + param->in_elements_num0_ *= param->in_shape0_[i]; + param->in_elements_num1_ *= param->in_shape1_[i]; + param->out_elements_num_ *= param->out_shape_[i]; + } + return NNACL_OK; +} + +REG_INFER(Add, PrimType_AddFusion, ArithmeticInferShape) +REG_INFER(BiasAdd, PrimType_BiasAdd, ArithmeticInferShape) +REG_INFER(Div, PrimType_DivFusion, ArithmeticInferShape) +REG_INFER(Eltwise, PrimType_Eltwise, ArithmeticInferShape) +REG_INFER(FloorDiv, PrimType_FloorDiv, ArithmeticInferShape) +REG_INFER(FloorMod, PrimType_FloorMod, ArithmeticInferShape) +REG_INFER(LogicalAnd, PrimType_LogicalAnd, ArithmeticInferShape) +REG_INFER(LogicalOr, PrimType_LogicalOr, ArithmeticInferShape) +REG_INFER(Maximum, PrimType_Maximum, ArithmeticInferShape) +REG_INFER(Minimum, PrimType_Minimum, ArithmeticInferShape) +REG_INFER(Mod, PrimType_Mod, ArithmeticInferShape) +REG_INFER(Mul, PrimType_MulFusion, ArithmeticInferShape) +REG_INFER(RealDiv, PrimType_RealDiv, ArithmeticInferShape) +REG_INFER(Sub, PrimType_SubFusion, ArithmeticInferShape) +REG_INFER(SquaredDifference, PrimType_SquaredDifference, ArithmeticInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d7a53551ab0a4e28b4a0d7a4681a6f4449d889fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_INFER_H +#define MINDSPORE_NNACL_ARITHMETIC_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/arithmetic_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARITHMETIC_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d4a594969d018ee15e825c17e2e94eca2b88c982 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/assert_op_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int AssertOpInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return NNACL_OK; +} + +REG_INFER(Assert, PrimType_Assert, AssertOpInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..49836662b0001249835abfacd1fa60a743706e97 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSERT_OP_INFER_H +#define MINDSPORE_NNACL_ASSERT_OP_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssertOpInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSERT_OP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6de394efbe63d12b3c60490aaac3d6820c4c5a5c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/assign_add_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int AssignAddInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x = inputs[0]; + const TensorC *y = inputs[1]; + TensorC *out = outputs[0]; + if (x->data_type_ != y->data_type_) { + return NNACL_ERR; + } + SetDataTypeFormat(out, x); + SetShapeTensor(out, x); + return NNACL_OK; +} + +REG_INFER(AssignAdd, PrimType_AssignAdd, AssignAddInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ec5cda41e28bedfdc5ea9779db6b8c11dd8276eb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSIGN_ADD_INFER_H +#define MINDSPORE_NNACL_ASSIGN_ADD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssignAddInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSIGN_ADD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..350c55a724f1c81c916486d2a208bd0a107492fe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/assign_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int AssignInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1])) { + return NNACL_ERR; + } + + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(Assign, PrimType_Assign, AssignInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4d3a836467620e766c01506c9dd2ac302bd0dcce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSIGN_INFER_H +#define MINDSPORE_NNACL_ASSIGN_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssignInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSIGN_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3f0ee5af56184fc1cfe3c053407fb0a1e3c518f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/attention_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/attention_parameter.h" + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 7, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + AttentionParameter *param = (AttentionParameter *)parameter; + const TensorC *q_input = inputs[FIRST_INPUT]; + const TensorC *k_input = inputs[SECOND_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, q_input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *q_weight = inputs[FOURTH_INPUT]; + if (q_input->shape_size_ != C2NUM && q_input->shape_size_ != C3NUM) { + return NNACL_ERR; + } + if (q_weight->shape_size_ != C2NUM) { + return NNACL_ERR; + } + int batch = (q_input->shape_size_ == C2NUM) ? 1 : q_input->shape_[0]; + int f_seq = (q_input->shape_size_ == C2NUM) ? q_input->shape_[0] : q_input->shape_[1]; + int t_seq_len = k_input->shape_[1]; + if (q_input->shape_size_ == C2NUM) { + output0->shape_[FIRST_INPUT] = batch * f_seq; + output0->shape_[SECOND_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C2NUM; + } else { + output0->shape_[FIRST_INPUT] = batch; + output0->shape_[SECOND_INPUT] = f_seq; + output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C3NUM; + } + if (outputs_size >= C3NUM) { + TensorC *output1 = outputs[SECOND_INPUT]; + SetDataTypeFormat(output1, q_input); + output1->shape_[FIRST_INPUT] = batch; + output1->shape_[SECOND_INPUT] = param->head_num_; + output1->shape_[THIRD_INPUT] = param->head_size_; + output1->shape_[FOURTH_INPUT] = t_seq_len; + output1->shape_size_ = C4NUM; + TensorC *output2 = outputs[THIRD_INPUT]; + SetDataTypeFormat(output2, q_input); + output2->shape_[FIRST_INPUT] = batch; + output2->shape_[SECOND_INPUT] = param->head_num_; + output2->shape_[THIRD_INPUT] = t_seq_len; + output2->shape_[FOURTH_INPUT] = param->head_size_; + output2->shape_size_ = C4NUM; + } + return NNACL_OK; +} + +REG_INFER(Attention, PrimType_Attention, AttentionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f483d03fbf09ba9cc3527ce8f9c21462649998e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ATTENTION_INFER_H +#define MINDSPORE_NNACL_ATTENTION_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ATTENTION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a2ebeba9654c3007cff9e25fb8e94aa5c2346a79 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/audio_spectrogram_infer.h" +#include "nnacl_c/infer/infer_register.h" + +unsigned Log2Ceil(unsigned length) { + if (length == 0) { + return 0; + } + const unsigned origin_length = length; + int floor = 0; + for (int i = 4; i >= 0; --i) { + const unsigned shift = (1 << (unsigned)i); + unsigned tmp = length >> shift; + if (tmp != 0) { + length = tmp; + floor += shift; + } + } + // the loop always reduces length to 1, so the power-of-two test must use the original value + return origin_length == (origin_length & ~(origin_length - 1)) ? floor : floor + 1; +} + +unsigned GetFftLength(unsigned length) { + unsigned shift = Log2Ceil(length); + return 1 << shift; +} + +int AudioSpectrogramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 2) { + return NNACL_ERR; + } + AudioSpectrogramParameter *param = (AudioSpectrogramParameter *)parameter; + if (param->window_size_ < 2) { + return NNACL_ERR; + } + if (param->stride_ < 1) { + return NNACL_ERR; + } + int output_shape[3]; + output_shape[0] = input->shape_[1]; + int sample_sub_window = input->shape_[0] - param->window_size_; + output_shape[1] = sample_sub_window < 0 ? 
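/* illustrative numbers, not from this patch: 1024 input samples with window_size_ 512 and stride_ 256 give sample_sub_window = 512, so the frame count below evaluates to 1 + 512 / 256 = 3 */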
0 : 1 + sample_sub_window / param->stride_; + // compute fft length + int fft_length = (int)GetFftLength(param->window_size_); + output_shape[2] = fft_length / 2 + 1; + SetShapeArray(output, output_shape, 3); + return NNACL_OK; +} + +REG_INFER(AudioSpectrogram, PrimType_AudioSpectrogram, AudioSpectrogramInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e0b0d6e3de4cf0ab8d122fdf81380f4dd64d0a4a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H +#define MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct AudioSpectrogramParameter { + OpParameter op_parameter_; + int window_size_; + int stride_; +} AudioSpectrogramParameter; + +int AudioSpectrogramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..87607774d571b3453e3bd9a91c28d47912e68d96 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/batch_to_space_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SetOutputShapeFromParam(const TensorC *const *inputs, TensorC **outputs, const OpParameter *parameter) { + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + + if (input_shape_size != 4) { + return NNACL_PARAM_INVALID; + } + + const BatchToSpaceParameter *param = (const BatchToSpaceParameter *)parameter; + const int32_t *block_shape = param->block_shape_; + const int32_t *crops = param->crops_; + int mul_block_shape = 1; + + for (size_t i = 0; i < 2; ++i) { + if (block_shape[i] <= 0) { + return NNACL_PARAM_INVALID; + } + if (input_shape[kNHWC_N] % block_shape[i]) { + return NNACL_ERR; + } + mul_block_shape *= block_shape[i]; + } + + if (input_shape[kNHWC_N] < mul_block_shape) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < 4; ++i) { + if (crops[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (mul_block_shape == 0) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N] / mul_block_shape; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_shape[0] - crops[0] - crops[1]; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_shape[1] - crops[2] - crops[3]; + output_shape[kNHWC_C] = input_shape[kNHWC_C]; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int SetOutputShapeFromInput(const TensorC *const *inputs, TensorC **outputs) { + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + if (input_shape_size != 4) { + return NNACL_PARAM_INVALID; + } + int *block_shape = (int *)(inputs[1]->data_); + int *crops = (int *)(inputs[2]->data_); + if (NNACLGetElementNum(inputs[1]) != 2) { + return NNACL_PARAM_INVALID; + } + if (NNACLGetElementNum(inputs[2]) != 4) { + return NNACL_PARAM_INVALID; + } + int mul_block_shape_ = 1; + + for (size_t i = 0; i < 2; ++i) { + if (block_shape[i] <= 0) { + return NNACL_PARAM_INVALID; + } + if (input_shape[kNHWC_N] % block_shape[i]) { + return 1; + } + mul_block_shape_ *= block_shape[i]; + } + + if (input_shape[kNHWC_N] < mul_block_shape_) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < 4; ++i) { + if (crops[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (mul_block_shape_ == 0) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N] / mul_block_shape_; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_shape[0] - crops[0] - crops[1]; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_shape[1] - crops[2] - crops[3]; + output_shape[kNHWC_C] = input_shape[kNHWC_C]; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (outputs_size != 1 || (inputs_size != 1 && inputs_size != 3)) { + return NNACL_PARAM_INVALID; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + 
SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 1) { + ret = SetOutputShapeFromParam(inputs, outputs, parameter); + return ret; + } + if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + ret = SetOutputShapeFromInput(inputs, outputs); + return ret; +} + +REG_INFER(BatchToSpace, PrimType_BatchToSpace, BatchToSpaceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fa073047a29735f9a61844a7e47cd02d5e990cb9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H +#define MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/batch_to_space_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a698b1394fa15359fff3b3119eac36d2e6a4c954 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/bias_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int BiasGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + if (in0->shape_size_ > MAX_SHAPE_SIZE || in0->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + int inshape[] = {in0->shape_[in0->shape_size_ - 1]}; + size_t inshape_size = 1; + SetDataTypeFormat(out, in0); + SetShapeArray(out, inshape, inshape_size); + + return NNACL_OK; +} + +REG_INFER(BiasAddGrad, PrimType_BiasAddGrad, BiasGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..0ab82b54247b92c556c845c14e0e637125d48bb5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BIAS_GRAD_INFER_H +#define MINDSPORE_NNACL_BIAS_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BiasGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BIAS_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bb45ff28eb5bd9b6f09cf1898a6adefdaeb10c44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/binary_cross_entropy_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + const TensorC *x = inputs[0]; + TensorC *out = outputs[0]; + SetDataTypeFormat(out, x); + BinaryCrossEntropyParameter *param = (BinaryCrossEntropyParameter *)parameter; + ReductionType reduction = (ReductionType)(param->reduction); + if (reduction == Reduction_Mean || reduction == Reduction_Sum) { + out->shape_size_ = 1; + out->shape_[0] = 1; + } else { + SetShapeTensor(out, x); + } + return NNACL_OK; +} + +REG_INFER(BinaryCrossEntropy, PrimType_BinaryCrossEntropy, BinaryCrossEntropyInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..18e66918b69b341ebdbe6a8a538e8b3d00f46b81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H +#define MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..989f96dd260246b68267d91baa8ca7dd500f22ff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/bn_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 6, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in = inputs[1]; + if ((inputs[0]->shape_size_ == 4 && inputs[0]->format_ != Format_NHWC) || + (in->shape_size_ == 4 && in->format_ != Format_NHWC)) { + return NNACL_FORMAT_ERROR; + } + const TensorC *scale = inputs[2]; + SetShapeTensor(outputs[0], in); + SetDataTypeFormat(outputs[0], in); + SetShapeTensor(outputs[1], scale); + SetDataTypeFormat(outputs[1], scale); + SetShapeTensor(outputs[2], scale); + SetDataTypeFormat(outputs[2], scale); + return NNACL_OK; +} + +REG_INFER(BatchNormGrad, PrimType_BatchNormGrad, BnGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fd607a4314cadb51ac8367ce74de1797d3ee7d40 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BN_GRAD_INFER_H +#define MINDSPORE_NNACL_BN_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BN_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f08a31863e8ccb449c482e57d5591a82866970da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/broadcast_to_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int GetShapeByType(const TensorC *shape_tensor, int shape_size, int *dst_shape) { + if (shape_tensor == NULL || dst_shape == NULL) { + return NNACL_ERR; + } + if (shape_size == 0) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_NULL_RETURN_ERR(shape_tensor->data_); + switch (shape_tensor->data_type_) { + case kNumberTypeInt8: { + int8_t *data = (int8_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeInt32: { + int32_t *data = (int32_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeInt64: { + int64_t *data = (int64_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = (int)data[i]; + } + } break; + case kNumberTypeFloat: { + float *data = (float *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeUInt32: { + uint32_t *data = (uint32_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = (int)data[i]; + } + } break; + default: { + return NNACL_ERR; + } + } + return NNACL_OK; +} + +void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1) { + if (input_shape0_size < input_shape1_size) { + *ndim = input_shape1_size; + int fill_dim_num = input_shape1_size - input_shape0_size; + int j = 0; + for (int i = 0; i < input_shape1_size; i++) { + if (i < fill_dim_num) { + in_shape0[i] = 1; + } else { + in_shape0[i] = input_shape0[j++]; + } + in_shape1[i] = input_shape1[i]; + } + } else if (input_shape0_size > input_shape1_size) { + *ndim = input_shape0_size; + int fill_dim_num = input_shape0_size - input_shape1_size; + int j = 0; + for (int i = 0; i < input_shape0_size; i++) { + if (i < fill_dim_num) { + in_shape1[i] = 1; + } else { + in_shape1[i] = input_shape1[j++]; + } + in_shape0[i] = input_shape0[i]; + } + } else { + for (int i = 0; i < input_shape0_size; i++) { + in_shape1[i] = input_shape1[i]; + in_shape0[i] = input_shape0[i]; + } + } +} + +int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape, + bool *has_broad_cast) { + for (int i = 0; i < ndim; i++) { + if (in_shape0[i] != in_shape1[i]) { + if (in_shape0[i] == 1) { + out_shape[i] = in_shape1[i]; + } else if (in_shape1[i] == 1) { + out_shape[i] = in_shape0[i]; + } else { + return NNACL_ERR; + } + *has_broad_cast = true; + } else { + out_shape[i] = in_shape0[i]; + } + } + return NNACL_OK; +} + +int BroadCastToShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *out_shape, bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter 
*parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size != 1 && inputs_size != 2) { + return NNACL_ERR; + } + if (outputs_size != 1) { + return NNACL_ERR; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int dst_shape[MAX_SHAPE_SIZE] = {0}; + int dst_shape_size; + const int *input_shape = input->shape_; + int input_shape_size = input->shape_size_; + int output_shape[MAX_SHAPE_SIZE] = {0}; + int ndim = input_shape_size; + bool has_broad_cast = false; + if (inputs_size == 1) { + BroadcastToParameter *param = (BroadcastToParameter *)parameter; + dst_shape_size = (int)param->shape_size_; + if (dst_shape_size > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (int i = 0; i < dst_shape_size; i++) { + dst_shape[i] = param->shape_[i]; + } + } else { + const TensorC *shape_tensor = inputs[1]; + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + dst_shape_size = NNACLGetElementNum(shape_tensor); + if (dst_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ret = GetShapeByType(shape_tensor, dst_shape_size, dst_shape); + if (ret != NNACL_OK) { + return ret; + } + for (int i = 0; i < dst_shape_size; ++i) { + if (dst_shape[i] == -1) { + dst_shape[i] = inputs[0]->shape_[i]; + } + } + } + + if (BroadCastToShape(input_shape_size, dst_shape_size, input_shape, dst_shape, &ndim, output_shape, + &has_broad_cast) != NNACL_OK) { + return NNACL_ERR; + } + + SetShapeArray(outputs[0], output_shape, (size_t)ndim); + return NNACL_OK; +} + +REG_INFER(BroadcastTo, PrimType_BroadcastTo, BroadcastToInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9d607f2d4bbb96304cd2f7f7b40b7a47d466c592 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_BROADCAST_TO_INFER_H +#define MINDSPORE_NNACL_BROADCAST_TO_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/base/broadcast_to.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1); +int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape, + bool *has_broad_cast); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BROADCAST_TO_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c8d715c923b32846e14cba86eb77ea21a455bdb9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c @@ -0,0 +1,77 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/cast_gather_reduce_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/split_parameter.h" + +int CastGatherReduceFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const size_t kMinimumGradInputsNum = 3; + if (inputs_size < kMinimumGradInputsNum || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (inputs[C2NUM]->data_ == NULL) { + return NNACL_NULL_PTR; + } + int axis = *((int *)inputs[C2NUM]->data_); + if (axis < 0) { + axis += input->shape_size_; + } + int indices_shape[MAX_SHAPE_SIZE]; + size_t indices_shape_size = 0; + ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_); + size_t indices_rank = indices_shape_size; + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + if ((int)(in_shape_size) < axis + 1) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); + int erase_ret = ShapeErase(out_shape, &out_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + for (int i = 
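/* note, not from this patch: ShapeInsert at a fixed axis is applied from the last indices dim down, so the indices dims keep their original order in the output shape */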
(int)(indices_rank - 1); i >= 0; --i) { + ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + out_shape[1] = 1; + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(CastGatherReduceFusion, PrimType_Inner_CastGatherReduceFusion, CastGatherReduceFusionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..da9c11267296360ed37eacfad47a21a074907862 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H +#define MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CastGatherReduceFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..035697ba63fa594c049aa0866a4ec228138c77e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/cast_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 2) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->format_ = input->format_; + const TensorC *dst_type = inputs[1]; + if (dst_type->data_ == NULL) { + return NNACL_NULL_PTR; + } + output->data_type_ = *((int *)dst_type->data_); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 && + input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 && + input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 && + input->data_type_ != kNumberTypeFloat16) { + return NNACL_INPUT_TENSOR_ERROR; + } + + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Cast, PrimType_Cast, CastInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..530516b599257a795e599ec47f2cf734388a19bf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CAST_INFER_H +#define MINDSPORE_NNACL_CAST_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CAST_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..39ab31c2f643e515c8f4c2c4fecbefd6f357c10d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c @@ -0,0 +1,338 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use tensor file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/common_infer.h" +#include <stdlib.h> +#include <string.h> +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/tensorlist_c_utils.h" + +bool CheckShaleValid(TensorC **tensors, int tensors_size) { + for (int i = 0; i < tensors_size; i++) { + TensorC *t = tensors[i]; + for (size_t j = 0; j < t->shape_size_; j++) { + if (t->shape_[j] == -1) { + return false; + } + } + } + return true; +} + +bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size) { + for (int i = 0; i < in_size; i++) { + TensorC *t = in[i]; + for (size_t j = 0; j < t->shape_size_; j++) { + if (t->shape_[j] == -1) { + return false; + } + } + } + for (int i = 0; i < out_size; i++) { + TensorC *t = out[i]; + for (size_t j = 0; j < t->shape_size_; j++) { + if (t->shape_[j] == -1) { + return false; + } + } + } + return true; +} + +void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) { + size_t i = 0; + for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) { + dst_shape[i] = src_shape[i]; + } + *dst_shape_size = i; +} + +bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size) { + if (dst_shape_size == NULL || dst_shape == NULL || src_shape == NULL) { + return false; + } + size_t i = 0; + for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) { + if (MS_UNLIKELY(src_shape[i] > (int64_t)INT32_MAX || src_shape[i] < (int64_t)INT32_MIN)) { + return false; + } + dst_shape[i] = (int32_t)(src_shape[i]); + } + *dst_shape_size = i; + return true; +} + +void ShapePush(int *shape, size_t *shape_size, int value) { + if (*shape_size >= MAX_SHAPE_SIZE) { + return; + } + shape[*shape_size] = value; + *shape_size = *shape_size + 1; +} + +int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size) { + if (tensor->data_ == NULL || result == NULL || result_size == NULL) { + return NNACL_ERR; + } + if (tensor->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ele_num = NNACLGetElementNum(tensor); + if (ele_num <= 0) { + return NNACL_ERR; + } + *result_size = (size_t)ele_num; + if (tensor->data_type_ == kNumberTypeInt || tensor->data_type_ == kNumberTypeInt32) { + int *data = (int *)(tensor->data_); + for (int i = 0; i < ele_num; i++) { + result[i] = data[i]; + } + } else if (tensor->data_type_ == kNumberTypeInt64) { + int64_t *data = (int64_t *)(tensor->data_); + for (int i = 0; i < ele_num; i++) { + if (data[i] >= INT32_MAX) { + return NNACL_ERR; + } + result[i] = (int32_t)data[i]; + } + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int ShapeInsert(int *shape, size_t *shape_size, int index, int value) { + if (index < 0 || index > *shape_size) { + return NNACL_ERR; + } + if (*shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = *shape_size; i > index; i--) { + shape[i] = shape[i - 1]; + } + shape[index] = value; + *shape_size = *shape_size + 1; + return NNACL_OK; +} + +int ShapeErase(int *shape, size_t *shape_size, int index) { + if (index < 0 || index >= *shape_size) { + return NNACL_ERR; + } + + for (int i = index; i < *shape_size - 1; i++) { + shape[i] = shape[i + 1]; + } + *shape_size = *shape_size - 1; + return NNACL_OK; +} + +bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size) { + if (shape0_size != shape1_size) { 
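+ /* shapes of different rank can never be equal */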
+ return false; + } + for (size_t i = 0; i < shape0_size; i++) { + if (shape0[i] != shape1[i]) { + return false; + } + } + return true; +} + +void iswap(int *a, int *b) { + int tmp = *a; + *a = *b; + *b = tmp; +} + +int imin(int a, int b) { return a > b ? b : a; } + +int imax(int a, int b) { return a < b ? b : a; } + +// used when the output is an exact copy of the input (shape, dtype and format), e.g. +// 1. zeros_like +int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(inputs[0]->shape_size_ == inputs[1]->shape_size_, NNACL_ERR); + for (int i = 0; i < inputs[0]->shape_size_; i++) { + if (inputs[0]->shape_[i] != inputs[1]->shape_[i]) { + return NNACL_ERR; + } + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + if (inputs[0]->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeFloat32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int 
input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + if (input_shape_size == 0) { + return NNACL_ERR; + } + input_shape_size--; + SetShapeArray(output, input_shape, input_shape_size); + return NNACL_OK; +} + +bool InferFlag(const TensorC *const *inputs, size_t inputs_size) { + if (inputs == NULL) { + return false; + } + for (size_t i = 0; i < inputs_size; i++) { + if (inputs[i] == NULL) { + return false; + } + if (inputs[i]->data_type_ == kObjectTypeTensorType) { + if (InferFlagTensorList((TensorC *)inputs[i]) == false) { + return false; + } + } else { + for (size_t j = 0; j < inputs[i]->shape_size_; ++j) { + if (inputs[i]->shape_[j] < 0) { + return false; + } + } + } + } + return true; +} + +REG_INFER(Abs, PrimType_Abs, CommonInferShape) +REG_INFER(AbsGrad, PrimType_AbsGrad, CommonGradInferShape) +REG_INFER(Activation, PrimType_Activation, CommonInferShape) +REG_INFER(BatchNorm, PrimType_BatchNorm, CommonInferShape) +REG_INFER(BinaryCrossEntropyGrad, PrimType_BinaryCrossEntropyGrad, CommonInferShape) +REG_INFER(Ceil, PrimType_Ceil, CommonInferShape) +REG_INFER(Clip, PrimType_Clip, CommonInferShape) +REG_INFER(Cos, PrimType_Cos, CommonInferShape) +REG_INFER(Depend, PrimType_Depend, CommonInferShape) +REG_INFER(Elu, PrimType_Elu, CommonInferShape) +REG_INFER(Erf, PrimType_Erf, CommonInferShape) +REG_INFER(Exp, PrimType_ExpFusion, CommonInferShape) +REG_INFER(FakeQuantWithMinMaxVars, PrimType_FakeQuantWithMinMaxVars, CommonInferShape) +REG_INFER(Floor, PrimType_Floor, CommonInferShapeWithOneInput) +REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape) +REG_INFER(Log, PrimType_Log, CommonInferShape) +REG_INFER(Log1p, PrimType_Log1p, CommonInferShape) +REG_INFER(LogGrad, PrimType_LogGrad, CommonGradInferShape) +REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape) +REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC) +REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape) +REG_INFER(Neg, PrimType_Neg, CommonInferShape) +REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape) +REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape) +REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape) +REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape) +REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape) +REG_INFER(ReverseSequence, PrimType_ReverseSequence, CommonInferShape) +REG_INFER(Reverse, PrimType_ReverseV2, CommonInferShape) +REG_INFER(Round, PrimType_Round, CommonInferShape) +REG_INFER(Rsqrt, PrimType_Rsqrt, CommonInferShape) +REG_INFER(Scale, PrimType_ScaleFusion, CommonInferShape) +REG_INFER(SigmoidCrossEntropyWithLogits, PrimType_SigmoidCrossEntropyWithLogits, CommonInferShape) +REG_INFER(SigmoidCrossEntropyWithLogitsGrad, PrimType_SigmoidCrossEntropyWithLogitsGrad, CommonInferShape) +REG_INFER(Sin, PrimType_Sin, CommonInferShape) +REG_INFER(SmoothL1Loss, PrimType_SmoothL1Loss, CommonInferShape) +REG_INFER(SmoothL1LossGrad, PrimType_SmoothL1LossGrad, CommonInferShape) +REG_INFER(Sqrt, PrimType_Sqrt, CommonInferShape) +REG_INFER(SqrtGrad, PrimType_SqrtGrad, CommonInferShape) +REG_INFER(Square, PrimType_Square, CommonInferShape) +REG_INFER(ZerosLike, PrimType_ZerosLike, CommonInferShape) +REG_INFER(ScatterElements, PrimType_ScatterElements, CommonInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h new file mode 100644 index 
0000000000000000000000000000000000000000..24ddfb3f10220f0ef7c2ba560ef1d7e03d0f65ed --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h @@ -0,0 +1,94 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_COMMON_H_ +#define MINDSPORE_NNACL_COMMON_H_ + +#include <stddef.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" + +bool CheckShaleValid(TensorC **tensors, int tensors_size); +bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size); + +#define EPSILON_VALUE 1e-6 + +enum NNACLLshProjectionType { + LshProjectionType_UNKNOWN = 0, + LshProjectionType_SPARSE = 1, + LshProjectionType_DENSE = 2, + LshProjectionType_MIN = LshProjectionType_UNKNOWN, + LshProjectionType_MAX = LshProjectionType_DENSE +}; + +typedef struct VectorC { + int *data_; + size_t size_; + size_t max_size_; + size_t per_malloc_size_; +} VectorC; + +#ifdef __cplusplus +extern "C" { +#endif + +int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter); +int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj); +int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0, + size_t inputs_size_obj_1, size_t outputs_size_obj); +int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj); +int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t outputs_size_obj); +int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj); +void SetDataTypeFormat(TensorC *dst, const TensorC *src); + +void SetShapeTensor(TensorC *dst, const TensorC *src); +void SetShapeArray(TensorC *dst, const int *src, size_t src_size); +void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size); +bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size); +void ShapePush(int *shape, size_t *shape_size, int value); +int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size); +int ShapeInsert(int *shape, size_t *shape_size, int index, int value); +int ShapeErase(int *shape, size_t *shape_size, int index); +bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size); + +void iswap(int *a, int *b); + +int imin(int a, int b); +int imax(int a, int 
b); + +int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); +int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter); +bool InferFlag(const TensorC *const *inputs, size_t inputs_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_COMMON_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..94f838a17b75487cc4ba2e8de6655ba2f20f8ce6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/concat_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DataTypeJudge(const TensorC *input, const TensorC *output) { + if ((input->data_type_ != output->data_type_) && + !((input->data_type_ == kNumberTypeFloat16 && output->data_type_ == kNumberTypeFloat32) || + (input->data_type_ == kNumberTypeFloat32 && output->data_type_ == kNumberTypeFloat16))) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *input0_shape = inputs[0]->shape_; + size_t input0_shape_size = inputs[0]->shape_size_; + + ConcatParameter *param = (ConcatParameter *)parameter; + int axis = param->axis_ < 0 ? 
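/* e.g. a negative axis counts from the back: axis -1 on a rank-4 input normalizes to -1 + 4 = 3 (illustrative note, not from this patch) */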
param->axis_ + (int)input0_shape_size : param->axis_; + if (axis < 0 || axis >= (int)input0_shape_size) { + return NNACL_ERR; + } + if (input0_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; + size_t input0_shape_without_axis_size = 0; + ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_axis_dim = input0_shape[axis]; + for (size_t i = 1; i < inputs_size; ++i) { + size_t input_i_shape_size = inputs[i]->shape_size_; + if (input_i_shape_size != input0_shape_size) { + return NNACL_PARAM_INVALID; + } + int shape_tmp[MAX_SHAPE_SIZE] = {0}; + size_t shape_tmp_size = 0; + ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); + int data_type_judge = DataTypeJudge(inputs[i], output); + if (data_type_judge != NNACL_OK) { + return data_type_judge; + } + int axis_tmp = shape_tmp[axis]; + erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + if (!ShapeEqual(input0_shape_without_axis, input0_shape_without_axis_size, shape_tmp, shape_tmp_size)) { + return NNACL_ERR; + } + output_axis_dim += axis_tmp; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input0_shape_size; + for (size_t i = 0; i < input0_shape_size; i++) { + output_shape[i] = input0_shape[i]; + } + output_shape[axis] = output_axis_dim; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Concat, PrimType_Concat, ConcatInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c743f3cf793218f575c12ba205ae9c46de65add3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CONCAT_INFER_H +#define MINDSPORE_NNACL_CONCAT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/concat_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONCAT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e595b041897f734a9e220ab90df5538d4d8cd3e4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/constant_of_shape_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ConstantOfShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + ConstantOfShapeParameter *param = (ConstantOfShapeParameter *)parameter; + out_tensor->data_type_ = (TypeIdC)(param->data_type_); + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size) || in_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int size = NNACLGetElementNum(in_tensor); + if (size < 0 || size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE]; + int out_shape_size = size; + switch (in_tensor->data_type_) { + case kNumberTypeInt32: { + int32_t *in_data = (int32_t *)(in_tensor->data_); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + if (out_shape[i] < 0) { + return NNACL_ERR; + } + } + break; + } + case kNumberTypeInt64: { + int64_t *in_data = (int64_t *)(in_tensor->data_); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + if (out_shape[i] < 0) { + return NNACL_ERR; + } + } + break; + } + default: + return NNACL_INFER_INVALID; + } + + SetShapeArray(out_tensor, out_shape, (size_t)out_shape_size); + return NNACL_OK; +} + +REG_INFER(ConstantOfShape, PrimType_ConstantOfShape, ConstantOfShapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b7ccc57c533c2b4e472e56654a0d7177308303ff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under 
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H +#define MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConstantOfShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..de4b0cdb6e9ca3aacdfcf55797f7f44b9d69ce36 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/control/tensor_array_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_array_parameter.h" + +int TensorArrayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { +#ifdef Debug + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } +#endif + + TensorC *output = outputs[0]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + TensorArrayParameter *param = (TensorArrayParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + output->data_type_ = param->data_type_; + SetShapeArray(output, param->element_shape_, (size_t)param->element_shape_size_); + + return NNACL_OK; +} + +REG_INFER(TensorArray, PrimType_TensorArray, TensorArrayInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6bdb26a51e56809c11b54876e905a2527e55c2c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..fbbf22a702b7a3b188b07f3834d51334c76254da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/control/tensor_array_read_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_array_parameter.h" + +int TensorArrayReadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // { prim, handle, index } -> node + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size >= 1, NNACL_ERR); + NNACL_CHECK_TRUE_RET(outputs_size >= 1, NNACL_ERR); + TensorC *handle = (TensorC *)inputs[0]; + TensorC *output = outputs[0]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + output->data_type_ = handle->data_type_; + SetShapeArray(output, handle->shape_, handle->shape_size_); + + return NNACL_OK; +} + +REG_INFER(TensorArrayRead, PrimType_TensorArrayRead, TensorArrayReadInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a9fa6fa9c705971f1b9121487d5cd0baed003d76 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayReadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..460bee83b7b8828ccdf45a8bd6054086d0d67f1b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/control/tensor_array_write_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_array_parameter.h" + +int TensorArrayWriteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // { handle, index, value, flow_in } -> empty + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size >= 3, NNACL_ERR); + TensorC *handle = (TensorC *)inputs[0]; + TensorC *value = (TensorC *)inputs[2]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + TensorArrayParameter *param = (TensorArrayParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + if (handle->shape_size_ != value->shape_size_) { + return NNACL_INFER_INVALID; + } + + for (size_t i = 0; i < handle->shape_size_; ++i) { + if (handle->shape_[i] != value->shape_[i]) { + return NNACL_INFER_INVALID; + } + } + + return NNACL_OK; +} + +REG_INFER(TensorArrayWrite, PrimType_TensorArrayWrite, TensorArrayWriteInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..224c44e7e87a0aafcdfa36c2914cc18d5a0ac854 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayWriteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c812e2d3d97aedc8e19454f837d6e572c78f9c65 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListC *output = (TensorListC *)(outputs[0]); + const TensorC *input0 = inputs[0]; + output->data_type_ = kObjectTypeTensorType; + output->format_ = Format_NHWC; + output->tensors_data_type_ = input0->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input0->shape_size_ < 1) { + return NNACL_ERR; + } + int dim0 = input0->shape_[0]; + if (dim0 < 0) { + return NNACL_ERR; + } + const TensorC *input1 = inputs[1]; + if (input1->data_ == NULL) { + return NNACL_NULL_PTR; + } + int *ele_shape_ptr = (int *)(input1->data_); + NNACL_CHECK_NULL_RETURN_ERR(ele_shape_ptr); + vvector tensor_shape; + tensor_shape.size_ = (size_t)(dim0); + tensor_shape.shape_ = (int **)malloc(tensor_shape.size_ * sizeof(int *)); + if (tensor_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + tensor_shape.shape_size_ = (int *)malloc(tensor_shape.size_ * sizeof(int)); + if (tensor_shape.shape_size_ == NULL) { + free(tensor_shape.shape_); + return NNACL_NULL_PTR; + } + for (int i = 0; i < dim0; i++) { + tensor_shape.shape_[i] = (int *)(input0->shape_ + 1); + tensor_shape.shape_size_[i] = (int)(input0->shape_size_) - 1; + } + + ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, (size_t)NNACLGetElementNum(input1)); + output->element_num_ = (size_t)(dim0); + int ret = MallocTensorListData(output, input0->data_type_, &tensor_shape); + if (ret != NNACL_OK) { + free(tensor_shape.shape_); + free(tensor_shape.shape_size_); + return NNACL_ERR; + } + free(tensor_shape.shape_); + free(tensor_shape.shape_size_); + return NNACL_OK; +} + +REG_INFER(TensorListFromTensor, PrimType_TensorListFromTensor, TensorListFromTensorInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..884e0422017b6111fe21e6f306c62a475308c517 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..227043f3793b8d0fa2998e0f6ab478c01f1f49d0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c @@ -0,0 +1,102 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/control/tensorlist_getitem_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (inputs[0]->data_type_ != kObjectTypeTensorType) { + return NNACL_ERR; + } + TensorListC *input0 = (TensorListC *)(inputs[0]); + const TensorC *get_index = inputs[1]; + if (get_index->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(get_index) != 1) { + return NNACL_ERR; + } + TensorC *output = outputs[0]; + if (!InferFlag(inputs, inputs_size) || input0->element_num_ == 0) { + return NNACL_INFER_INVALID; + } + int index = ((int *)(get_index->data_))[0]; + if (index < 0 || index > ((int)(input0->element_num_ - 1))) { + return NNACL_ERR; + } + TensorC *tensor_index = input0->tensors_[index]; + NNACL_CHECK_NULL_RETURN_ERR(tensor_index); + + if (tensor_index->data_type_ != kTypeUnknown) { + output->data_type_ = tensor_index->data_type_; + } else { + output->data_type_ = input0->tensors_data_type_; + } + output->format_ = input0->tensors_[index]->format_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (tensor_index->data_type_ != kTypeUnknown) { + ShapeSet(output->shape_, &(output->shape_size_), tensor_index->shape_, tensor_index->shape_size_); + } else { + const TensorC *input2 = inputs[2]; + NNACL_CHECK_NULL_RETURN_ERR(input2); + NNACL_CHECK_NULL_RETURN_ERR(input2->data_); + int *ele_shape_data = (int *)(input2->data_); + NNACL_CHECK_NULL_RETURN_ERR(ele_shape_data); + int element_shape[MAX_SHAPE_SIZE] = {0}; + size_t element_shape_size = 0; + for (int i = 0; i < NNACLGetElementNum(input2); ++i) { + ShapePush(element_shape, &element_shape_size, ele_shape_data[i]); + } + int status = + TensorListMergeShape(element_shape, &element_shape_size, input0->element_shape_, input0->element_shape_size_); + if (status != NNACL_OK) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(element_shape, element_shape_size)) { + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *input = input0->tensors_[i]; + NNACL_CHECK_NULL_RETURN_ERR(input); + if (input->data_type_ != kTypeUnknown) { + status = TensorListMergeShape(element_shape, &element_shape_size, input->shape_, input->shape_size_); + if (status != NNACL_OK) { + return NNACL_ERR; + } + } + } + } + if (!TensorListIsFullyDefined(element_shape, element_shape_size)) { // the pre is the same judge condition + return NNACL_ERR; + } + + SetShapeArray(output, element_shape, element_shape_size); + } + + return NNACL_OK; +} + +REG_INFER(TensorListGetItem, PrimType_TensorListGetItem, TensorListGetItemInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..19d4804ad0f6cc1d3b4887870761fad4d911c81a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/tensorlist_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..544041a7b924135b0585fd41aa76cd9da1c66fd1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/control/tensorlist_reserve_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_parameter.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListParameter *reserve_param = (TensorListParameter *)parameter; + const TensorC *input0 = inputs[0]; + int ele_shape_type = input0->data_type_; + if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) { + return NNACL_ERR; + } + + TensorListC *output = (TensorListC *)(outputs[0]); + output->data_type_ = kObjectTypeTensorType; + output->format_ = Format_NHWC; + output->tensors_data_type_ = reserve_param->element_dtype_; + + if (input0->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int *ele_shape_ptr = (int *)(input0->data_); + + const TensorC *input1 = inputs[1]; + int num_ele_type = input1->data_type_; + if (num_ele_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) { + return NNACL_ERR; + } + if (input1->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(input1) != 1) { + return NNACL_ERR; + } + int num_elements = ((int *)(input1->data_))[0]; + ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, (size_t)NNACLGetElementNum(input0)); + output->element_num_ = (size_t)(num_elements); + + vvector tmp_shape; + tmp_shape.size_ = (size_t)(num_elements); + tmp_shape.shape_ = (int **)malloc(tmp_shape.size_ * sizeof(int *)); + if (tmp_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + tmp_shape.shape_size_ = (int *)malloc(tmp_shape.size_ * sizeof(int)); + if (tmp_shape.shape_size_ == NULL) { + free(tmp_shape.shape_); + return NNACL_NULL_PTR; + } + + for (size_t i = 0; i < num_elements; ++i) { + tmp_shape.shape_size_[i] = output->element_shape_size_; + tmp_shape.shape_[i] = output->element_shape_; + } + int ret = MallocTensorListData(output, reserve_param->element_dtype_, &tmp_shape); + free(tmp_shape.shape_size_); + free(tmp_shape.shape_); + return ret; +} + +REG_INFER(TensorListReserve, PrimType_TensorListReserve, TensorListReserveInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1a753f364d69036f5df8a5564054736c64b5413f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4b2c77ceafd1144fc5ecf8915e9890f9f04b23c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensorlist_setitem_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int PreJudge(const TensorC *get_index, TensorListC *input0, const TensorC *value_tensor) { + if (get_index->data_ == NULL) { + return NNACL_INFER_INVALID; + } + + if (get_index->data_type_ != kNumberTypeInt && get_index->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + if (NNACLGetElementNum(get_index) != 1) { + return NNACL_ERR; + } + if (get_index->data_ == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListC *input0 = (TensorListC *)(inputs[0]); + const TensorC *get_index = inputs[1]; + const TensorC *value_tensor = inputs[2]; + TensorListC *output0 = (TensorListC *)(outputs[0]); + output0->data_type_ = input0->data_type_; + output0->format_ = input0->format_; + output0->tensors_data_type_ = value_tensor->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int judge_ret = PreJudge(get_index, input0, value_tensor); + if (judge_ret != NNACL_OK) { + return judge_ret; + } + + int index = ((int *)(get_index->data_))[0]; + output0->max_elements_num_ = input0->max_elements_num_; + + if (input0->element_num_ == 0 && input0->element_shape_size_ == 0 && index == 0) { + ShapeSet(input0->element_shape_, &(input0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_); + ShapeSet(output0->element_shape_, &(output0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_); + } else { + ShapeSet(output0->element_shape_, &(output0->element_shape_size_), input0->element_shape_, + input0->element_shape_size_); + } + 
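// The vvector built below is a staging list of per-element shapes for the
// output tensor list: each slot merely aliases an existing shape_ array (no
// copies are made), and one slot beyond element_num_ is allocated so that the
// unconditional out_shape.shape_[index] assignment near the end of this
// function stays in bounds when the write appends at index == element_num_.
// MallocTensorListData then consumes this vvector to allocate output0's
// element tensors.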
+ vvector out_shape; + out_shape.size_ = 0; + out_shape.shape_ = (int **)malloc((input0->element_num_ + 1) * sizeof(int *)); + if (out_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + out_shape.shape_size_ = (int *)malloc((input0->element_num_ + 1) * sizeof(int)); + if (out_shape.shape_size_ == NULL) { + free(out_shape.shape_); + return NNACL_NULL_PTR; + } + + if (index == 0 && input0->element_num_ == 0) { // uninitialized tensorlist + out_shape.shape_[out_shape.size_] = (int *)(value_tensor->shape_); + out_shape.shape_size_[out_shape.size_] = value_tensor->shape_size_; + out_shape.size_++; + output0->element_num_ = 1; + } else { + output0->element_num_ = input0->element_num_; + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *src_ptr = input0->tensors_[i]; + if (src_ptr == NULL) { + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_NULL_PTR; + } + if (src_ptr->data_type_ != kTypeUnknown) { + out_shape.shape_[out_shape.size_] = src_ptr->shape_; + out_shape.shape_size_[out_shape.size_] = (int)(src_ptr->shape_size_); + out_shape.size_++; + } else { + out_shape.shape_[out_shape.size_] = NULL; + out_shape.shape_size_[out_shape.size_] = 0; + out_shape.size_++; + } + } + } + + if (input0->tensors_data_type_ == kTypeUnknown) { + input0->tensors_data_type_ = value_tensor->data_type_; + } + + out_shape.shape_[index] = (int *)(value_tensor->shape_); + out_shape.shape_size_[index] = (int)value_tensor->shape_size_; + int ret = MallocTensorListData(output0, input0->tensors_data_type_, &out_shape); + if (ret != NNACL_OK) { + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_ERR; + } + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_OK; +} + +REG_INFER(TensorListSetItem, PrimType_TensorListSetItem, TensorListSetItemInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..066860e93b6c73ba340c2ac169d60983b4f81394 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bf15df5eba879730e1559b7e74741db8d7867fdb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c @@ -0,0 +1,96 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensorlist_stack_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *output = outputs[0]; + if (inputs[0]->data_type_ != kObjectTypeTensorType) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorListC *input0 = (TensorListC *)(inputs[0]); + output->data_type_ = input0->tensors_data_type_; + output->format_ = input0->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input0->element_num_ == 0) { + return NNACL_INFER_INVALID; + } + const TensorC *ele_shape = inputs[1]; // element shape + if (ele_shape->data_ == NULL) { + return NNACL_NULL_PTR; + } + int *ele_shape_ptr = (int *)(ele_shape->data_); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (ele_shape_ptr[0] == -1) { + if (input0->element_shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (size_t i = 0; i < input0->element_shape_size_; i++) { + ShapePush(output_shape, &output_shape_size, input0->element_shape_[i]); + } + } else { + int ele_shape_num = NNACLGetElementNum(ele_shape); + if (ele_shape_num > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < ele_shape_num; ++i) { + ShapePush(output_shape, &output_shape_size, ele_shape_ptr[i]); + } + } + + int status = + TensorListMergeShape(output_shape, &output_shape_size, input0->element_shape_, input0->element_shape_size_); + if (status == NNACL_ERR) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(output_shape, output_shape_size)) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(input0->element_shape_, input0->element_shape_size_)) { + for (size_t i = 0; i < 
input0->element_num_; ++i) { + TensorC *tensor_ele = input0->tensors_[i]; + if (tensor_ele->data_type_ != kTypeUnknown) { + status = TensorListMergeShape(output_shape, &output_shape_size, tensor_ele->shape_, tensor_ele->shape_size_); + if (status == NNACL_ERR) { + return NNACL_ERR; + } + } + } + } + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(TensorListStack, PrimType_TensorListStack, TensorListStackInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d9c2b84a0feb8b7c9073f6dbf8c76f6fda68d49c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..44e73686c73a553b9efee2154c78188e83afc7c0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <string.h> +#include "nnacl_c/infer/conv2d_grad_filter_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_ERR; + } + if (inputs[FIRST_INPUT]->format_ != Format_NHWC || inputs[SECOND_INPUT]->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[FIRST_INPUT], inputs[FIRST_INPUT]); + + if (inputs[THIRD_INPUT]->shape_size_ < DIMENSION_1D || inputs[THIRD_INPUT]->data_ == NULL) { + return NNACL_ERR; + } + if (inputs[THIRD_INPUT]->shape_[kNCHW_N] < 0) { + return NNACL_ERR; + } + size_t filter_shape_size = (size_t)(inputs[THIRD_INPUT]->shape_[kNCHW_N]); + if (filter_shape_size != DIMENSION_4D) { + return NNACL_ERR; + } + + int filter_shape[MAX_SHAPE_SIZE]; + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < filter_shape_size; i++) { + filter_shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(filter_shape, inputs[THIRD_INPUT]->data_, filter_shape_size * sizeof(int)); + } else { + return NNACL_ERR; + } + SetShapeArray(outputs[0], filter_shape, filter_shape_size); + return NNACL_OK; +} + +REG_INFER(Conv2DBackpropFilterFusion, PrimType_Conv2DBackpropFilterFusion, Conv2dGradFilterInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4939068cd32453645c7b3269d6d6a876f7d34ea7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H +#define MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..62311d21dfda6ae63a9491cc46916a588104e63e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/conv2d_grad_input_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + if (in0 == NULL || out == NULL) { + return NNACL_NULL_PTR; + } + if (in0->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(out, in0); + + if (inputs[THIRD_INPUT]->shape_size_ < 1 || inputs[THIRD_INPUT]->data_ == NULL) { + return NNACL_ERR; + } + size_t data_size = (size_t)inputs[2]->shape_[0]; + if (data_size != 4) { + return NNACL_ERR; + } + + int shape[MAX_SHAPE_SIZE]; + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[4] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < data_size; i++) { + shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(shape, inputs[THIRD_INPUT]->data_, data_size * sizeof(int)); + } else { + return NNACL_ERR; + } + SetShapeArray(out, shape, data_size); + return NNACL_OK; +} + +REG_INFER(Conv2DBackpropInputFusion, PrimType_Conv2DBackpropInputFusion, Conv2dGradInputInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..58349fed917efa6035a23ff9b903b2098e9de857 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H +#define MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..42ff7d22b9f2edaf97df964b881f4bd120a98e78 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c @@ -0,0 +1,169 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/infer/conv2d_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ConvInferShape(int input_h, int input_w, int *output_h, int *output_w, ConvParameter *param) { + int kernel_w = param->kernel_w_; + int kernel_h = param->kernel_h_; + int stride_w = param->stride_w_; + int stride_h = param->stride_h_; + int dilate_w = param->dilation_w_; + int dilate_h = param->dilation_h_; + + if (stride_w == 0 || stride_h == 0) { + return NNACL_PARAM_INVALID; + } + if (INT_MUL_OVERFLOW(kernel_h, dilate_h) || INT_MUL_OVERFLOW(kernel_w, dilate_w)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + if (param->pad_mode_ == Pad_same) { // maybe error + *output_w = ceil((float)(input_w) / (float)(stride_w)); + *output_h = ceil((float)(input_h) / (float)(stride_h)); + int pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h); + int pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w); + if (pad_h_all < 0) { + param->pad_u_ = param->pad_d_ = 0; + } else { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all < 0) { + param->pad_l_ = param->pad_r_ = 0; + } else { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } else if (param->pad_mode_ == Pad_valid) { + *output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - ((float)(kernel_w)-1) * (float)(dilate_w)) / + (float)(stride_w)); + *output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - ((float)(kernel_h)-1) * (float)(dilate_h)) / + (float)(stride_h)); + } else { + int kernel_width = (kernel_w - 1) * dilate_w + 1; + int kernel_height = (kernel_h - 1) * dilate_h + 1; + *output_w = ((input_w) + param->pad_l_ + param->pad_r_ - kernel_width) / stride_w + 1; + *output_h = ((input_h) + param->pad_u_ + param->pad_d_ - kernel_height) / stride_h + 1; + } + + if (param->kernel_h_ > input_h + param->pad_u_ + param->pad_d_ || + param->kernel_w_ > input_w + param->pad_l_ + param->pad_r_) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +static const int MAX_CONV_KERNEL_DIM = 10000; // One big value that should not be adopted as the conv kernel dimension. 
+ +int CheckConvAttr(const int input_c, const TensorC *weight_tensor, const ConvParameter *param) { + // common conv: input_c == weight_tensor->shape_[3] + // conv depthwise: input_c == 1 + // group conv: input_c / group == weight_tensor->shape_[3] + NNACL_CHECK_FALSE(param->group_ == 0, NNACL_PARAM_INVALID); + if (input_c != weight_tensor->shape_[3] && input_c != 1 && (input_c / param->group_) != weight_tensor->shape_[3]) { + return NNACL_PARAM_INVALID; + } + + // common conv: group == 1 + // conv depthwise: group == input_c == output_c + // group conv: group == input_c / weight_tensor->shape_[3] + NNACL_CHECK_FALSE(weight_tensor->shape_[3] == 0, NNACL_PARAM_INVALID); + if (param->group_ != 1 && param->group_ != input_c && param->group_ != (input_c / weight_tensor->shape_[3])) { + return NNACL_PARAM_INVALID; + } + if (param->stride_h_ <= 0 || param->stride_w_ <= 0) { + return NNACL_PARAM_INVALID; + } + + if ((param->kernel_h_ >= MAX_CONV_KERNEL_DIM) || (param->kernel_w_ >= MAX_CONV_KERNEL_DIM)) { + return NNACL_PARAM_INVALID; + } + + NNACL_CHECK_TRUE_RET(param->kernel_h_ == weight_tensor->shape_[1], NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(param->kernel_w_ == weight_tensor->shape_[2], NNACL_PARAM_INVALID); + return NNACL_OK; +} + +int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_tensor = inputs[0]; + if (input_tensor->format_ != Format_NHWC && input_tensor->format_ != Format_KHWC && + input_tensor->format_ != Format_NC4HW4 && input_tensor->format_ != Format_NC8HW8) { + return NNACL_FORMAT_ERROR; + } + const TensorC *weight_tensor = inputs[1]; + if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) { + return NNACL_FORMAT_ERROR; + } + TensorC *out_tensor = outputs[0]; + if (out_tensor->format_ != Format_NC4HW4) { + out_tensor->format_ = input_tensor->format_; + } + out_tensor->data_type_ = input_tensor->data_type_; + ConvParameter *param = (ConvParameter *)parameter; + if (param->group_ == 0) { + param->group_ = weight_tensor->shape_[0]; + } + param->output_channel_ = weight_tensor->shape_[0]; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : weight_tensor->shape_[1]; + param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : weight_tensor->shape_[2]; + + if (input_tensor->shape_size_ == 0) { + return NNACL_INFER_INVALID; + } + + int ret = CheckConvAttr(NNACLGetChannel(input_tensor), weight_tensor, param); + if (ret != NNACL_OK) { + return ret; + } + + int output_w = 0, output_h = 0; + ret = ConvInferShape(NNACLGetHeight(input_tensor), NNACLGetWidth(input_tensor), &output_h, &output_w, param); + if (ret != NNACL_OK) { + return ret; + } + + out_tensor->shape_size_ = input_tensor->shape_size_; + NNACLSetBatch(out_tensor, NNACLGetBatch(input_tensor)); + NNACLSetChannel(out_tensor, NNACLGetBatch(weight_tensor)); + output_h = output_h >= 0 ? output_h : 1; + NNACLSetHeight(out_tensor, output_h); + output_w = output_w >= 0 ? 
output_w : 1; + NNACLSetWidth(out_tensor, output_w); + + param->input_batch_ = NNACLGetBatch(input_tensor); + param->input_h_ = NNACLGetHeight(input_tensor); + param->input_w_ = NNACLGetWidth(input_tensor); + param->input_channel_ = NNACLGetChannel(input_tensor); + param->output_batch_ = NNACLGetBatch(out_tensor); + param->output_h_ = NNACLGetHeight(out_tensor); + param->output_w_ = NNACLGetWidth(out_tensor); + param->output_channel_ = NNACLGetChannel(out_tensor); + param->out_format_ = out_tensor->format_; + return NNACL_OK; +} + +REG_INFER(Adder, PrimType_AdderFusion, Conv2dInferShape) +REG_INFER(Conv2D, PrimType_Conv2DFusion, Conv2dInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c38b83b4909538fb528a3ea03272a9134266c9b4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV2D_INFER_H +#define MINDSPORE_NNACL_CONV2D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif  // MINDSPORE_NNACL_CONV2D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..fbed116eec8d02ee49e723901c0c0d007b76690c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c @@ -0,0 +1,27 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/conv3d_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int Conv3dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // Conv3D shape inference is intentionally left unimplemented; this stub only keeps the InferShape pass + // from being interrupted and leaves the node's shape as {}.
+ return NNACL_OK; +} + +REG_INFER(Conv3D, PrimType_Inner_Conv3D, Conv3dInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9c4ee48a7042b694be9ce4f9d211fbcdf227120f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV3D_INFER_H +#define MINDSPORE_NNACL_CONV3D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv3dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV3D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cd8e2f45a8a36fb32c15f82e36a93f4c0f681850 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c @@ -0,0 +1,69 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/crop_and_resize_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 0 && input->shape_size_ != 4) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (NNACLGetBatch(input) == 0) { + ShapePush(output_shape, &output_shape_size, 0); + } else if (inputs[1]->data_ != NULL) { + const TensorC *boxes_tensor = inputs[1]; + if (boxes_tensor->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapePush(output_shape, &output_shape_size, boxes_tensor->shape_[0]); + } else { + return NNACL_INFER_INVALID; + } + + const TensorC *shape_tensor = inputs[3]; + int32_t *data = (int32_t *)(shape_tensor->data_); + if (data == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(shape_tensor) < 2) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapePush(output_shape, &output_shape_size, data[0]); + ShapePush(output_shape, &output_shape_size, data[1]); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(CropAndResize, PrimType_CropAndResize, CropAndResizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..7571b88ea0bc00d4e531aaa21842bd305be0f93f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H +#define MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..da920180a2342db1a90851aea2aa2151c9a48e74 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/crop_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/crop_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int CropInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + size_t input_shape_size = inputs[0]->shape_size_; + CropParameter *param = (CropParameter *)parameter; + int64_t axis = param->axis_ < 0 ? param->axis_ + (int64_t)input_shape_size : param->axis_; + if (axis < 0 || axis >= (int64_t)input_shape_size) { + return NNACL_ERR; + } + + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[1]); + return NNACL_OK; +} + +REG_INFER(Crop, PrimType_Crop, CropInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..aab2973742d223afe448ddd127fe220765b94d4f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CROP_INFER_H +#define MINDSPORE_NNACL_CROP_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CropInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CROP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ebd1d0e4991a8e2126cd8b34879e4e069fd80090 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/cumsum_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CumsumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Cumsum, PrimType_CumSum, CumsumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..877ae308a9a2dfd0d9e1ec61ae42269b401a2dcc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUMSUM_INFER_H +#define MINDSPORE_NNACL_CUMSUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CumsumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUMSUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2963e460c1701ec16ebe1b412c1e13e1f5a295d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_gru_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != C3NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + SetShapeTensor(output, input); + const TensorC *weight_in = inputs[1]; + if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + output->shape_[C2NUM] = weight_in[0].shape_[0] / C3NUM; + return NNACL_OK; +} + +REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d154a9715c8efccd77c57541b3a331b1cf89b053 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..740eff8439af0750117e14fe7f1bfaa9d57da296 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_is_inf_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + output->data_type_ = kNumberTypeBool; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..87b8731e7bdbc59ac4fd7a6f94e0f79737bae954 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H +#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..302dec2ff64fb50d47f1095dc5e45e20110450e4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_masked_fill_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(CustomMaskedFill, PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..844f5e0e792fa0e22b5065029653244ed3d81b06 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H +#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..da6b55e06a21297a50429b8ad90aec71478ccb61 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_tensor_scatter_max_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(CustomTensorScatterMax, PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f19cccb78ec0af61b8f8cba146c86a5a97e94d81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H +#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6257e58e3931d92c19fc110c3e255be85767bf16 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/infer/decoder_layer_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DecoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C16NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(DecoderLayer, PrimType_Inner_DecoderLayer, DecoderLayerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..2b894fd1352e9e49789583552403b13679035ae8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DecoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..01b243fc25ab78d1f85f9a4335eede7488949554 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/deconv2d_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + const TensorC *weight = inputs[1]; + TensorC *output = outputs[0]; + output->format_ = input->format_; + output->data_type_ = input->data_type_; + + ConvParameter *param = (ConvParameter *)parameter; + if (param->group_ == 0) { + param->group_ = weight->shape_[0]; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int32_t input_h = NNACLGetHeight(input); + int32_t input_w = NNACLGetWidth(input); + + int32_t output_n = NNACLGetBatch(input); + int32_t output_h = 0; + int32_t output_w = 0; + int32_t output_c = NNACLGetChannel(weight); + NNACL_CHECK_TRUE_RET(NNACLGetChannel(input) == NNACLGetBatch(weight), NNACL_ERR); + if (param->group_ == NNACLGetChannel(input) && 1 == NNACLGetChannel(weight)) { + output_c = NNACLGetBatch(weight); /* depthwise */ + } + + int kernel_w = param->kernel_w_ != -1 ? param->kernel_w_ : NNACLGetWidth(weight); + int kernel_h = param->kernel_h_ != -1 ? 
param->kernel_h_ : NNACLGetHeight(weight); + NNACL_CHECK_FALSE(kernel_w <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(kernel_h <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_h, kernel_w), NNACL_ERR); + + int stride_w = param->stride_w_; + int stride_h = param->stride_h_; + NNACL_CHECK_FALSE(stride_w <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(stride_h <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input_h, stride_h), NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input_w, stride_w), NNACL_ERR); + + int dilate_w = param->dilation_w_; + int dilate_h = param->dilation_h_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_h, dilate_h), NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_w, dilate_w), NNACL_ERR); + + int pad_mode = param->pad_mode_; + if (pad_mode == Pad_pad) { + output_h = (input_h - 1) * stride_h + ((kernel_h - 1) * dilate_h + 1) - param->pad_u_ - param->pad_d_; + output_w = (input_w - 1) * stride_w + ((kernel_w - 1) * dilate_w + 1) - param->pad_l_ - param->pad_r_; + } else if (pad_mode == Pad_same) { + output_h = input_h * stride_h; + output_w = input_w * stride_w; + } else if (pad_mode == Pad_valid) { + output_h = (input_h - 1) * stride_h + kernel_h; + output_w = (input_w - 1) * stride_w + kernel_w; + } else { + return NNACL_ERR; + } + + output_h += param->output_padding_h_; + output_w += param->output_padding_w_; + + output->shape_size_ = 4; + output->shape_[0] = output_n; + output->shape_[1] = output_h; + output->shape_[2] = output_w; + output->shape_[3] = output_c; + + if (pad_mode == Pad_same) { + param->pad_u_ = ((input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - output_h) / 2; + param->pad_l_ = ((input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - output_w) / 2; + } else if (pad_mode == Pad_valid) { + param->pad_u_ = 0; + param->pad_l_ = 0; + } + + const int *in_shape = input->shape_; + param->input_batch_ = in_shape[0]; + param->input_h_ = in_shape[1]; + param->input_w_ = in_shape[2]; + param->input_channel_ = in_shape[3]; + param->output_batch_ = output_n; + param->output_h_ = output_h; + param->output_w_ = output_w; + param->output_channel_ = output_c; + param->kernel_h_ = kernel_h; + param->kernel_w_ = kernel_w; + return NNACL_OK; +} + +REG_INFER(Conv2dTranspose, PrimType_Conv2dTransposeFusion, Deconv2dInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a2c713b12f911a13c9de6129d89b6a7ce729888a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_DECONV2D_INFER_H +#define MINDSPORE_NNACL_DECONV2D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DECONV2D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bdd1eb1f652b8d8735c2dcd8d3856a3cadf9741a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/depth_to_space_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + DepthToSpaceParameter *param = (DepthToSpaceParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_PARAM_INVALID; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + + int32_t block_size = param->block_size_; + if (INT_MUL_OVERFLOW(block_size, block_size)) { + return NNACL_PARAM_INVALID; + } + if (block_size == 0 || input_shape[kNHWC_C] % (block_size * block_size) != 0 || input_shape[kNHWC_C] == 0) { + return NNACL_PARAM_INVALID; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N]; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_size; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_size; + output_shape[kNHWC_C] = input_shape[kNHWC_C] / (block_size * block_size); + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(DepthToSpace, PrimType_DepthToSpace, DepthToSpaceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6b67618ae5a1578241b9ae1348a2fd4cd78b22ca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + 
* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H +#define MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/depth_to_space_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..79b887302039981ff7ec5caff0db8de684fa20e6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/depthwise_conv2d_infer.h" +#include "nnacl_c/tensor_c_utils.h" + +int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + ConvParameter *param = (ConvParameter *)parameter; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + int input_channel = input->shape_[3]; + int output_w = 0, output_h = 0; + param->input_channel_ = input_channel; + + if (param->stride_h_ == 0 || param->stride_w_ == 0) { + return NNACL_PARAM_INVALID; + } + param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : NNACLGetHeight(inputs[kWeightIndex]); + param->kernel_w_ = param->kernel_w_ != -1 ? 
param->kernel_w_ : NNACLGetWidth(inputs[kWeightIndex]); + if (param->pad_mode_ == Pad_same) { + output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + int pad_h_all = ((output_h - 1) * param->stride_h_ + (param->kernel_h_ - 1) * param->dilation_h_ + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (param->kernel_w_ - 1) * param->dilation_w_ + 1 - input_w); + if (pad_h_all > 0) { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all > 0) { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } else { + output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - + ((float)(param->kernel_h_) - 1) * (float)(param->dilation_h_)) / + (float)(param->stride_h_)); + output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - + ((float)(param->kernel_w_) - 1) * (float)(param->dilation_w_)) / + (float)(param->stride_w_)); + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); + out_shape[1] = output_h; + out_shape[2] = output_w; + if (param->channel_multiplie_ != 1) { + return NNACL_ERR; + } + out_shape[3] = input_channel; // channel multiplier is 1, so output channels equal input channels + SetShapeArray(output, out_shape, out_shape_size); + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6230491e8813f530cd0893485c3d3430bc9f1646 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H +#define MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif  // MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c3e43e449a7b0139b4b9de1aac168675a51dd372 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/detection_post_process_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *boxes = inputs[0]; + const TensorC *scores = inputs[1]; + const TensorC *anchors = inputs[2]; + if (boxes->shape_size_ < 2 || scores->shape_size_ < 3 || anchors->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + DetectionPostProcessParameter *param = (DetectionPostProcessParameter *)parameter; + if (scores->shape_[2] < param->num_classes_) { + return NNACL_ERR; + } + if (scores->shape_[2] - param->num_classes_ > 1) { + return NNACL_ERR; + } + if (boxes->shape_[1] != scores->shape_[1]) { + return NNACL_ERR; + } + if (boxes->shape_[1] != anchors->shape_[0]) { + return NNACL_ERR; + } + + TensorC *detected_boxes = outputs[0]; + TensorC *detected_classes = outputs[1]; + TensorC *detected_scores = outputs[2]; + TensorC *num_det = outputs[3]; + + detected_boxes->format_ = boxes->format_; + detected_boxes->data_type_ = kNumberTypeFloat32; + detected_classes->format_ = boxes->format_; + detected_classes->data_type_ = kNumberTypeFloat32; + detected_scores->format_ = boxes->format_; + detected_scores->data_type_ = kNumberTypeFloat32; + num_det->format_ = boxes->format_; + num_det->data_type_ = kNumberTypeFloat32; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const int max_detections = param->max_detections_; + const int max_classes_per_detection = param->max_classes_per_detection_; + const int num_detected_boxes = (int)(max_detections * max_classes_per_detection); + detected_boxes->shape_size_ = 3; + detected_boxes->shape_[0] = 1; + detected_boxes->shape_[1] = num_detected_boxes; + detected_boxes->shape_[2] = 4; + detected_classes->shape_size_ = 2; + detected_classes->shape_[0] = 1; + detected_classes->shape_[1] = num_detected_boxes; + detected_scores->shape_size_ = 2; + detected_scores->shape_[0] = 1; + detected_scores->shape_[1] = num_detected_boxes; + num_det->shape_size_ = 1; + num_det->shape_[0] = 1; + + return NNACL_OK; +} + +REG_INFER(DetectionPostProcess, PrimType_DetectionPostProcess, DetectionPostProcessInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4c40cbe790757bd34d9dda39539f22ec13fceaff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H +#define MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/detection_post_process_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..dc07820da72acf6dd6fcac691f4b2a69ed345804 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/dropout_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DropoutGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(DropoutGrad, PrimType_DropoutGrad, DropoutGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f3ef875140477794e692f1c96aa389b75ccca0f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H +#define MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DropoutGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..96e4d263e0deb74468f691a286e1350929f26714 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/dropout_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DropoutInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + if (outputs_size > 1) { + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output1, input); + SetShapeTensor(output1, input); + } + return NNACL_OK; +} + +REG_INFER(Dropout, PrimType_Dropout, DropoutInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..73dae73e6e7ca862017fcb8a6c14eda14790d01b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DROPOUT_INFER_H +#define MINDSPORE_NNACL_DROPOUT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DropoutInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DROPOUT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..66022efcb1ae3b1e00c8b6770e834ca67adfaec7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/dynamic_quant_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/dynamic_quant_parameter.h" + +int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + DynamicQuantParameter *param = (DynamicQuantParameter *)parameter; + output->data_type_ = param->dst_type_; + NNACL_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR); + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(DynamicQuant, PrimType_DynamicQuant, DynamicQuantInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5ede6f2ac7b1c15caa17b7ff88e3a38537f2b4cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H +#define MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e37694371320b8805b2cf69ada95505e74061b05 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c @@ -0,0 +1,77 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/embedding_lookup_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 2 || outputs_size != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *params_ = inputs[0]; + const TensorC *ids = inputs[inputs_size - 1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, params_); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (params_->shape_size_ > MAX_SHAPE_SIZE || ids->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int embedding_shape[MAX_SHAPE_SIZE] = {0}; + size_t embedding_shape_size = 0; + ShapeSet(embedding_shape, &embedding_shape_size, params_->shape_, params_->shape_size_); + int erase_ret = ShapeErase(embedding_shape, &embedding_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, ids->shape_, ids->shape_size_); + for (size_t i = 0; i < embedding_shape_size; ++i) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, embedding_shape[i]); + } + for (size_t i = 1; i < inputs_size - 1; ++i) { + if (inputs[i]->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int embedding_shape_t[MAX_SHAPE_SIZE] = {0}; + size_t embedding_shape_t_size = 0; + ShapeSet(embedding_shape_t, &embedding_shape_t_size, inputs[i]->shape_, inputs[i]->shape_size_); + erase_ret = ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + bool t_equal = ShapeEqual(embedding_shape_t, embedding_shape_t_size, embedding_shape, embedding_shape_size); + if (!t_equal) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + SetShapeArray(output, output_shape, 
output_shape_size); + return NNACL_OK; +} + +REG_INFER(EmbeddingLookup, PrimType_EmbeddingLookupFusion, EmbeddingLookupInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..91715e315f82b186ce55b3823dd81b84c81000e8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H +#define MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..54d72144856e0156275bceb23b4c55ecb6cb57a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/infer/encoder_layer_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C9NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(EncoderLayer, PrimType_Inner_EncoderLayer, EncoderLayerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1c156b3580b21b2ba0aa5036aee92eec138d7db0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cb316584d746fa26e655c58124e59f0bfb8d06c1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/expand_dims_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size < C2NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + + if (inputs[1]->data_ == NULL) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (NNACLGetElementNum(inputs[1]) < 1) { + return NNACL_ERR; + } + int dim = ((int32_t *)(inputs[1]->data_))[0]; + if (dim < 0) { + dim += (int)(input->shape_size_) + 1; + } + if (dim > (int)(input->shape_size_)) { + return NNACL_INPUT_TENSOR_ERROR; + } + + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + int ret = ShapeInsert(output->shape_, &(output->shape_size_), dim, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + return NNACL_OK; +} + +REG_INFER(ExpandDims, PrimType_ExpandDims, ExpandDimsInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..db53049d5c296d8da5c723d44887da3622c46104 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_EXPAND_DIMS_INFER_H +#define MINDSPORE_NNACL_EXPAND_DIMS_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_EXPAND_DIMS_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c2865d38d35b4fdffe3823a9b77e11926d0f6972 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fft_imag_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FftImagInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return FftInferShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(FftImag, PrimType_FftImag, FftImagInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..44f5b6f6512d0c05e2acb3db86cd7525d02d2a88 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FFT_IMAG_INFER_H +#define MINDSPORE_NNACL_FFT_IMAG_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FftImagInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FFT_IMAG_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a1c3ccc316f4064a5b46918351a37555901ddac0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/fft_real_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FftRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return FftInferShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(FftReal, PrimType_FftReal, FftRealInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..0e233c6813e43aa69a02fcc7a1bc4266cdaa980d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FFT_REAL_INFER_H +#define MINDSPORE_NNACL_FFT_REAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FftRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FFT_REAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6b47d2a555e45d40992b0d422a571f95fde2b968 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/fill_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const TensorC *dst_shape_tensor = inputs[1]; + if (dst_shape_tensor->data_type_ != kNumberTypeInt && dst_shape_tensor->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + const int32_t *dst_shape = (int32_t *)(dst_shape_tensor->data_); + int num_dims = 1; + if (dst_shape_tensor->shape_size_ != DIMENSION_1D) { + return NNACL_ERR; + } + for (size_t i = 0; i < dst_shape_tensor->shape_size_; ++i) { + if (INT_MUL_OVERFLOW(num_dims, dst_shape_tensor->shape_[i])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + NNACL_CHECK_FALSE(dst_shape_tensor->shape_[i] < 0, NNACL_ERR); + num_dims *= dst_shape_tensor->shape_[i]; + } + if (num_dims != 0 && dst_shape == NULL) { + return NNACL_INFER_INVALID; + } + if (num_dims > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < num_dims; i++) { + ShapePush(output_shape, &output_shape_size, dst_shape[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Fill, PrimType_Fill, FillInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..cfe46b02b4f287ce026c1b03260a1888f3a956d2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FILL_INFER_H +#define MINDSPORE_NNACL_FILL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FILL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b2816757c45fcf2fb00a491f627ebc6f2678828e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fillv2_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FillV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const TensorC *dst_shape_tensor = inputs[0]; + const int32_t *dst_shape = (int32_t *)(dst_shape_tensor->data_); + int num_dims = 1; + if (dst_shape_tensor->shape_size_ != DIMENSION_1D) { + return NNACL_ERR; + } + for (size_t i = 0; i < dst_shape_tensor->shape_size_; ++i) { + if (INT_MUL_OVERFLOW(num_dims, dst_shape_tensor->shape_[i])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + NNACL_CHECK_FALSE(dst_shape_tensor->shape_[i] < 0, NNACL_ERR); + num_dims *= dst_shape_tensor->shape_[i]; + } + if (num_dims != 0 && dst_shape == NULL) { + return NNACL_INFER_INVALID; + } + if (num_dims > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < num_dims; i++) { + ShapePush(output_shape, &output_shape_size, dst_shape[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(FillV2, PrimType_FillV2, FillV2InferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..00d45bd5b79a45515d62c8ccbc89d9a2f84318f7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FILLV2_INFER_H +#define MINDSPORE_NNACL_FILLV2_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FillV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FILLV2_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cac743181b620cfc8fa972da876a697fc5d1f787 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/flatten_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int output_shape_size = inputs[1]->shape_[0]; + if (inputs[1]->data_ == NULL || output_shape_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + SetShapeArray(output, (int *)(inputs[1]->data_), (size_t)output_shape_size); + return NNACL_OK; +} + +REG_INFER(FlattenGrad, PrimType_FlattenGrad, FlattenGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..7fa843c36a6d52e531b0b1d1804881118a3035d0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H +#define MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b154952aa4eb61bea753c4f53b7d79a824560435 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/flatten_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/flatten_parameter.h" + +int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ <= 0 || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + FlattenParameter *param = (FlattenParameter *)parameter; + int axis = param->axis_; + // The value for axis must be in the range [-r, r], where r is + // the rank of the input tensor. Negative value means counting + // dimensions from the back, so adding r maps it into [0, r). + axis = axis < 0 ? (int)input_shape_size + axis : axis; + if (axis >= (int)input_shape_size) { + return NNACL_ERR; + } + int output_shape[2]; + output_shape[0] = axis == 0 ? 
1 : input_shape[0]; + for (size_t i = 1; i < (size_t)axis; i++) { + output_shape[0] *= input_shape[i]; + } + output_shape[1] = input_shape[axis]; + for (size_t i = (size_t)axis + 1; i < input_shape_size; i++) { + output_shape[1] *= input_shape[i]; + } + SetShapeArray(output, output_shape, 2); + return NNACL_OK; +} + +REG_INFER(Flatten, PrimType_Flatten, FlattenInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fc7671dc616d9fccd120111f1b3498e9bfe61c0e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FLATTEN_INFER_H +#define MINDSPORE_NNACL_FLATTEN_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FLATTEN_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6c720a2b40a54b38d32dfc95ee174ba066bcb988 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c @@ -0,0 +1,67 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/format_transpose_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/format_transpose_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int FormatTransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + FormatTransposeParameter *param = (FormatTransposeParameter *)parameter; + output->format_ = (int)(param->dst_format_); + output->data_type_ = input->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != DIMENSION_4D) { + SetShapeArray(output, input->shape_, input->shape_size_); + return NNACL_OK; + } + + int input_b = NNACLGetBatch(input); + int input_h = NNACLGetHeight(input); + int input_w = NNACLGetWidth(input); + int input_c = NNACLGetChannel(input); + + // set output shape + int out_shape[MAX_SHAPE_SIZE] = {0}; + out_shape[DIMENSION_0D] = input_b; + if (param->dst_format_ == Format_NCHW || param->dst_format_ == Format_NC4HW4 || param->dst_format_ == Format_NC8HW8) { + out_shape[DIMENSION_1D] = input_c; + out_shape[DIMENSION_2D] = input_h; + out_shape[DIMENSION_3D] = input_w; + } else if (param->dst_format_ == Format_NHWC) { + out_shape[DIMENSION_1D] = input_h; + out_shape[DIMENSION_2D] = input_w; + out_shape[DIMENSION_3D] = input_c; + } else { + return NNACL_ERR; + } + + SetShapeArray(output, out_shape, input->shape_size_); + return NNACL_OK; +} + +REG_INFER(FormatTranspose, PrimType_FormatTranspose, FormatTransposeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b8bb644fb01f86314d1d9627daaff1b145387d1b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H +#define MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FormatTransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a6d354ea00284a14e87349ddaf0df6450ffcbf53 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fse_decoder_infer.h" +#include "nnacl_c/infer/infer_register.h" + +size_t kInputSize = 7; + +int FseDecoderInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, kInputSize, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *cen_input = inputs[4]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, cen_input); + + return NNACL_OK; +} + +REG_INFER(FseDecode, PrimType_Inner_FseDecode, FseDecoderInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..2f93ba445427e005594e3ae2005a1555cf47cf19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FseDecoderInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e9ffefde5d80fab00c93c8f2aef6778b3db5acf6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/full_connection_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FullConnectionInferPreJudge(const MatMulParameter *param, size_t inputs_size, const TensorC *input0) { + if ((param->has_bias_ && inputs_size != 3) || (!param->has_bias_ && inputs_size != 2)) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->use_axis_ && (param->axis_ < 1 || param->axis_ > (int)(input0->shape_size_))) { + return NNACL_ERR; + } + return NNACL_OK; +} + +int FullConnectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + const TensorC *input1 = inputs[1]; + TensorC *output = outputs[0]; + MatMulParameter *param = (MatMulParameter *)parameter; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int pre_judge = FullConnectionInferPreJudge(param, inputs_size, input0); + if (pre_judge != NNACL_OK) { + return pre_judge; + } + int new_k = 1; + if (param->use_axis_) { + for (size_t i = (size_t)(param->axis_); i < input0->shape_size_; ++i) { + new_k *= input0->shape_[i]; + } + if (new_k != input1->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } else { + new_k = input1->shape_[1]; + } + if (param->has_bias_) { + if (inputs[2]->shape_[0] != input1->shape_[0]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + if (inputs[0]->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + if (param->use_axis_) { + out_shape_size = (size_t)(param->axis_) + 1; + out_shape[param->axis_] = input1->shape_[0]; + } else { + int total = 1; + for (size_t i = 0; i < input0->shape_size_; ++i) { + total 
*= input0->shape_[i]; + } + out_shape_size = 2; + if (new_k == 0) { + return NNACL_ERR; + } + int batch_size = total / new_k; + out_shape[0] = batch_size; + out_shape[1] = input1->shape_[0]; + } + SetShapeArray(output, out_shape, out_shape_size); + + return NNACL_OK; +} + +REG_INFER(FullConnection, PrimType_FullConnection, FullConnectionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..18cb1c7fe91c9aae022a077b71ca9e3989240054 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FULL_CONNECTION_INFER_H +#define MINDSPORE_NNACL_FULL_CONNECTION_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FullConnectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FULL_CONNECTION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0a8b247d9b7b6a6b21d566055b5889dc6add02c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/fused_batchnorm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + for (size_t i = 0; i < inputs_size; i++) { + if (outputs_size <= i) { + break; + } + SetShapeTensor(outputs[i], inputs[i]); + SetDataTypeFormat(outputs[i], inputs[i]); + } + if (outputs_size > 5) { + SetDataTypeFormat(outputs[5], inputs[0]); + outputs[5]->shape_size_ = 1; + outputs[5]->shape_[0] = 1; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(FusedBatchNorm, PrimType_FusedBatchNorm, FusedBatchNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9279dba16a65d4b2af106f0fa9c5063a4606f4b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H +#define MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2cd90d2122532a0c0d612053018cc2dc1ba707e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/gather_d_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GatherDInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const int input_size_limit = 3; + const int output_size_limit = 1; + if (inputs_size != input_size_limit || outputs_size != output_size_limit) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *index = inputs[2]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(output, index); + return NNACL_OK; +} + +REG_INFER(GatherD, PrimType_GatherD, GatherDInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..0b6000998b8e990b4e143751868e2fdd0db948d3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_GATHER_D_INFER_H +#define MINDSPORE_NNACL_GATHER_D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherDInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..653e1839d37383eda026fbf4b0cc43ba87905eaa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/gather_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const size_t kMinimumGradInputsNum = 3; + if (inputs_size < kMinimumGradInputsNum || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + if ((input->data_type_ == kNumberTypeInt8 || input->data_type_ == kNumberTypeInt16) && + (parameter->quant_type_ == Quant_QuantWeight || parameter->quant_type_ == Quant_QuantDynamic)) { + output->data_type_ = kNumberTypeFloat32; + } + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (inputs[2]->data_ == NULL) { + return NNACL_NULL_PTR; + } + if (NNACLGetElementNum(inputs[2]) < 1) { + return NNACL_ERR; + } + int axis = *((int *)inputs[2]->data_); + if (axis < 0) { + axis += input->shape_size_; + } + int indices_shape[MAX_SHAPE_SIZE]; + size_t indices_shape_size = 0; + ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_); + size_t indices_rank = indices_shape_size; + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + if ((int)(in_shape_size) < axis + 1) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); + int erase_ret = ShapeErase(out_shape, &out_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + for (int i = (int)(indices_rank - 1); i >= 0; --i) { + ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Gather, PrimType_Gather, GatherInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..25cafb2f9283a2731056ab48c771f4b447ff6177 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_GATHER_INFER_H
+#define MINDSPORE_NNACL_GATHER_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/gather_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_GATHER_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..27512661bcbb2d914db848a0d19900bc6e480537
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/gather_nd_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                       OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  const TensorC *indices = inputs[1];
+  TensorC *output = outputs[0];
+
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  int in_rank = (int)(input->shape_size_);
+  int indices_rank = (int)(indices->shape_size_);
+  for (int i = 0; i < indices_rank; i++) {
+    NNACL_CHECK_FALSE(indices->shape_[i] == 0, NNACL_ERR);
+  }
+  if (indices->shape_[indices_rank - 1] > in_rank) {
+    return NNACL_ERR;  // the last dim of indices indexes into input, so it cannot exceed the input rank
+  }
+  int i = 0;
+  int out_shape[MAX_SHAPE_SIZE] = {0};
+  size_t out_shape_size = 0;
+  for (i = 0; i < indices_rank - 1; ++i) {
+    ShapePush(out_shape, &out_shape_size, indices->shape_[i]);
+  }
+  for (i = indices->shape_[indices_rank - 1]; i < in_rank; ++i) {
+    ShapePush(out_shape, &out_shape_size, input->shape_[i]);
+  }
+  SetShapeArray(output, out_shape, out_shape_size);
+  return NNACL_OK;
+}
+
+REG_INFER(GatherNd, PrimType_GatherNd, GatherNdInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2a102fb36b8936d36671330bd8796d363caa13a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_GATHER_ND_INFER_H
+#define MINDSPORE_NNACL_GATHER_ND_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32/gatherNd_fp32.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                       OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_GATHER_ND_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..9ff6ecb423bdb4b0cbc2b2de59f9144382c96deb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/infer/glu_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/glu_parameter.h"
+
+int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                  OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  SetShapeTensor(output, input);
+  GluParameter *param = (GluParameter *)parameter;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  if (param->axis_ >= (int)input->shape_size_ || (param->axis_ < 0 && ((int)input->shape_size_ + param->axis_) < 0)) {
+    return NNACL_ERR;
+  }
+  int axis = param->axis_ >= 0 ? param->axis_ : (int)input->shape_size_ + param->axis_;  // axis_ == 0 must stay 0
+  if (axis < 0 || axis >= MAX_SHAPE_SIZE) {
+    return NNACL_BUFFER_OVERFLOW;
+  }
+  output->shape_[axis] /= 2;
+  return NNACL_OK;
+}
+
+REG_INFER(GLU, PrimType_GLU, GluInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32b34870c2c2c483460c56a086baf1e9b3b799c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_GLU_INFER_H
+#define MINDSPORE_NNACL_GLU_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/glu_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                  OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_GLU_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..25366bb6ffeb996cb68e101b2178507ae50e7165
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/infer/grid_sampler_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/grid_sampler_parameter.h"
+
+int GridSamplerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                          OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ != inputs[1]->shape_size_) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  if (input->shape_size_ < DIMENSION_4D) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  SetShapeTensor(output, input);
+  for (size_t i = DIMENSION_2D; i < input->shape_size_; ++i) {
+    output->shape_[i] = inputs[1]->shape_[i - 1];
+  }
+  return NNACL_OK;
+}
+
+REG_INFER(GridSampler, PrimType_Inner_GridSampler, GridSamplerInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..6110b83a0baec4e6a3504b7933ea5cebfb57b30a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_GRID_SAMPLER_INFER_H
+#define MINDSPORE_NNACL_GRID_SAMPLER_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/grid_sampler_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int GridSamplerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                          OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_GRID_SAMPLER_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..d7fc86de823f763d222442573e12782f48de8297
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/group_conv2d_grad_input_infer.h"
+
+int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                   size_t outputs_size, OpParameter *parameter) {
+  int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  if (inputs_size < 2 || outputs_size != 1) {
+    return NNACL_ERR;
+  }
+
+  const TensorC *in0 = inputs[0];
+  TensorC *out = outputs[0];
+
+  SetDataTypeFormat(out, in0);
+
+  size_t shape_size = in0->shape_size_;
+  if (shape_size > MAX_SHAPE_SIZE) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  int shape_[MAX_SHAPE_SIZE];
+  for (size_t i = 0; i < shape_size; i++) {
+    shape_[i] = in0->shape_[i];
+  }
+  SetShapeArray(out, shape_, shape_size);
+
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..e807f4843de5507bae50f4340b290b912b8afc10
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H +#define MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..696154933d03e834ca30a4248ec14b7b2b88e6f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/group_norm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(GroupNorm, PrimType_GroupNormFusion, GroupNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c9f2e245feb42f5d16979689e1c1d9d2aa33257c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..310bad7559558de1455b3ae1a6bea2083c42dd1d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/gru_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 5, 6, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *weight_gate = inputs[1]; + const TensorC *weight_recurrence = inputs[2]; + const TensorC *bias = inputs[3]; + TensorC *output = outputs[0]; + for (int i = 0; i < 2; i++) { + SetDataTypeFormat(outputs[i], input); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *in_shape = input->shape_; // seq_len, batch, input_size + const int *w_gate_shape = weight_gate->shape_; // num_direction, hidden_size * 3, input_size + const int *w_recu_shape = weight_recurrence->shape_; // num_direction, hidden_size * 3, hidden_size + const int *bias_shape = bias->shape_; // num_direction, hidden_size * 6 + if (input->shape_size_ != 3 || weight_gate->shape_size_ != 3 || weight_recurrence->shape_size_ != 3) { + return NNACL_ERR; + } + if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) { + return NNACL_ERR; + } + if (inputs_size == 6) { + const int *seq_len_shape = inputs[5]->shape_; + if (seq_len_shape[0] > 1) { + return NNACL_ERR; + } + if (inputs[5]->shape_size_ != 1 && seq_len_shape[0] != in_shape[1]) { + return NNACL_ERR; + } + } + + int hidden_size = w_gate_shape[1] / 3; + // set output + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, input->shape_size_); + out_shape[2] = hidden_size; + + GruParameter *param = (GruParameter *)parameter; + if (param->bidirectional_) { + int ret = 
ShapeInsert(out_shape, &out_shape_size, 1, 2); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } else { + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + SetShapeArray(output, out_shape, out_shape_size); + // set hidden state + int state_shape[MAX_SHAPE_SIZE]; + size_t state_shape_size = 0; + ShapeSet(state_shape, &state_shape_size, in_shape, input->shape_size_); + state_shape[0] = param->bidirectional_ ? 2 : 1; + state_shape[2] = hidden_size; + SetShapeArray(outputs[1], state_shape, state_shape_size); + return NNACL_OK; +} + +REG_INFER(GRU, PrimType_GRU, GruInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fc57baf2f62dd98aacd22148f06051648fb52085 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GRU_INFER_H +#define MINDSPORE_NNACL_GRU_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GRU_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c22403e1e34df56c4e911fb596a22d693fc45fd6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INFER_INFER_H_ +#define MINDSPORE_NNACL_INFER_INFER_H_ + +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef int (*InferShape)(const TensorC *const *inputs, size_t input_size, TensorC **outputs, size_t output_size, + OpParameter *parameter); + +InferShape GetInferFunc(int prim_type); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c new file mode 100644 index 0000000000000000000000000000000000000000..f8ec26dcf683b63dc3a2665cd1f01ca1c75aeade --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/infer_register.h" + +#ifdef _MSC_VER +#include "nnacl_c/infer/activation_grad_infer.h" +#include "nnacl_c/infer/adam_infer.h" +#include "nnacl_c/infer/adam_weight_decay_infer.h" +#include "nnacl_c/infer/add_sub_grad_infer.h" +#include "nnacl_c/infer/addn_infer.h" +#include "nnacl_c/infer/affine_infer.h" +#include "nnacl_c/infer/all_gather_infer.h" +#include "nnacl_c/infer/apply_momentum_infer.h" +#include "nnacl_c/infer/argmin_max_infer.h" +#include "nnacl_c/infer/arithmetic_compare_infer.h" +#include "nnacl_c/infer/arithmetic_grad_infer.h" +#include "nnacl_c/infer/arithmetic_infer.h" +#include "nnacl_c/infer/assert_op_infer.h" +#include "nnacl_c/infer/assign_add_infer.h" +#include "nnacl_c/infer/assign_infer.h" +#include "nnacl_c/infer/attention_infer.h" +#include "nnacl_c/infer/encoder_layer_infer.h" +#include "nnacl_c/infer/audio_spectrogram_infer.h" +#include "nnacl_c/infer/batch_to_space_infer.h" +#include "nnacl_c/infer/bias_grad_infer.h" +#include "nnacl_c/infer/binary_cross_entropy_infer.h" +#include "nnacl_c/infer/bn_grad_infer.h" +#include "nnacl_c/infer/broadcast_to_infer.h" +#include "nnacl_c/infer/cast_infer.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/infer/concat_infer.h" +#include "nnacl_c/infer/constant_of_shape_infer.h" +#include "nnacl_c/infer/decoder_layer_infer.h" + +#ifdef MSLITE_ENABLE_CONTROLFLOW +#include "nnacl_c/infer/control/tensor_array_infer.h" +#include "nnacl_c/infer/control/tensor_array_read_infer.h" +#include "nnacl_c/infer/control/tensor_array_write_infer.h" +#include "nnacl_c/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl_c/infer/control/tensorlist_getitem_infer.h" +#include "nnacl_c/infer/control/tensorlist_reserve_infer.h" +#include "nnacl_c/infer/control/tensorlist_setitem_infer.h" +#include "nnacl_c/infer/control/tensorlist_stack_infer.h" +#endif +#include "nnacl_c/infer/conv2d_grad_filter_infer.h" +#include "nnacl_c/infer/conv2d_grad_input_infer.h" +#include "nnacl_c/infer/conv2d_infer.h" +#include "nnacl_c/infer/crop_and_resize_infer.h" +#include "nnacl_c/infer/crop_infer.h" 
+#include "nnacl_c/infer/cumsum_infer.h" +#include "nnacl_c/infer/deconv2d_infer.h" +#include "nnacl_c/infer/depth_to_space_infer.h" +#include "nnacl_c/infer/depthwise_conv2d_infer.h" +#include "nnacl_c/infer/detection_post_process_infer.h" +#include "nnacl_c/infer/dropout_grad_infer.h" +#include "nnacl_c/infer/dropout_infer.h" +#include "nnacl_c/infer/dynamic_quant_infer.h" +#include "nnacl_c/infer/embedding_lookup_infer.h" +#include "nnacl_c/infer/expand_dims_infer.h" +#include "nnacl_c/infer/fft_imag_infer.h" +#include "nnacl_c/infer/fft_real_infer.h" +#include "nnacl_c/infer/fill_infer.h" +#include "nnacl_c/infer/fillv2_infer.h" +#include "nnacl_c/infer/flatten_grad_infer.h" +#include "nnacl_c/infer/flatten_infer.h" +#include "nnacl_c/infer/full_connection_infer.h" +#include "nnacl_c/infer/fused_batchnorm_infer.h" +#include "nnacl_c/infer/gather_infer.h" +#include "nnacl_c/infer/gather_nd_infer.h" +#include "nnacl_c/infer/glu_infer.h" +#include "nnacl_c/infer/group_conv2d_grad_input_infer.h" +#include "nnacl_c/infer/gru_infer.h" +#include "nnacl_c/infer/instance_norm_infer.h" +#include "nnacl_c/infer/invert_permutation_infer.h" +#include "nnacl_c/infer/layer_norm_grad_infer.h" +#include "nnacl_c/infer/layer_norm_infer.h" +#include "nnacl_c/infer/lin_space_infer.h" +#include "nnacl_c/infer/log_softmax_infer.h" +#include "nnacl_c/infer/lstm_grad_data_infer.h" +#include "nnacl_c/infer/lstm_grad_infer.h" +#include "nnacl_c/infer/lstm_grad_weight_infer.h" +#include "nnacl_c/infer/lstm_infer.h" +#include "nnacl_c/infer/matmul_infer.h" +#include "nnacl_c/infer/max_min_grad_infer.h" +#include "nnacl_c/infer/mfcc_infer.h" +#include "nnacl_c/infer/nllloss_grad_infer.h" +#include "nnacl_c/infer/nllloss_infer.h" +#include "nnacl_c/infer/non_max_suppression_infer.h" +#include "nnacl_c/infer/one_hot_infer.h" +#include "nnacl_c/infer/pad_infer.h" +#include "nnacl_c/infer/pooling_grad_infer.h" +#include "nnacl_c/infer/pooling_infer.h" +#include "nnacl_c/infer/power_infer.h" +#include "nnacl_c/infer/prior_box_infer.h" +#include "nnacl_c/infer/quant_dtype_cast_infer.h" +#include "nnacl_c/infer/ragged_range_infer.h" +#include "nnacl_c/infer/random_normal_infer.h" +#include "nnacl_c/infer/random_standard_normal_infer.h" +#include "nnacl_c/infer/range_infer.h" +#include "nnacl_c/infer/rank_infer.h" +#include "nnacl_c/infer/reduce_infer.h" +#include "nnacl_c/infer/reduce_scatter_infer.h" +#include "nnacl_c/infer/reshape_infer.h" +#include "nnacl_c/infer/resize_grad_infer.h" +#include "nnacl_c/infer/resize_infer.h" +#include "nnacl_c/infer/rfft_infer.h" +#include "nnacl_c/infer/roi_pooling_infer.h" +#include "nnacl_c/infer/scatter_nd_infer.h" +#include "nnacl_c/infer/scatter_nd_update_infer.h" +#include "nnacl_c/infer/select_infer.h" +#include "nnacl_c/infer/sgd_infer.h" +#include "nnacl_c/infer/invalid_infer.h" +#ifndef RUNTIME_PASS_CLIP +#include "nnacl_c/infer/shape_fusion_infer.h" +#endif +#include "nnacl_c/infer/shape_infer.h" +#include "nnacl_c/infer/size_infer.h" +#include "nnacl_c/infer/slice_infer.h" +#include "nnacl_c/infer/softmax_cross_entropy_infer.h" +#include "nnacl_c/infer/softmax_infer.h" +#include "nnacl_c/infer/space_to_batch_infer.h" +#include "nnacl_c/infer/space_to_batch_nd_infer.h" +#include "nnacl_c/infer/space_to_depth_infer.h" +#include "nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h" +#include "nnacl_c/infer/sparse_to_dense_infer.h" +#include "nnacl_c/infer/splice_infer.h" +#include "nnacl_c/infer/split_infer.h" +#include "nnacl_c/infer/split_with_over_lap_infer.h" 
+#include "nnacl_c/infer/squeeze_infer.h" +#include "nnacl_c/infer/stack_infer.h" +#include "nnacl_c/infer/strided_slice_grad_infer.h" +#include "nnacl_c/infer/strided_slice_infer.h" +#ifdef MSLITE_ENABLE_STRING_KERNEL +#include "nnacl_c/infer/string/custom_extract_features_infer.h" +#include "nnacl_c/infer/string/custom_normalize_infer.h" +#include "nnacl_c/infer/string/custom_predict_infer.h" +#include "nnacl_c/infer/string/hashtable_lookup_infer.h" +#include "nnacl_c/infer/string/lsh_projection_infer.h" +#include "nnacl_c/infer/string/skip_gram_infer.h" +#endif +#include "nnacl_c/infer/tile_infer.h" +#include "nnacl_c/infer/topk_infer.h" +#include "nnacl_c/infer/transpose_infer.h" +#include "nnacl_c/infer/uniform_real_infer.h" +#include "nnacl_c/infer/unique_infer.h" +#include "nnacl_c/infer/unsorted_segment_sum_infer.h" +#include "nnacl_c/infer/unsqueeze_infer.h" +#include "nnacl_c/infer/unstack_infer.h" +#include "nnacl_c/infer/where_infer.h" +#include "nnacl_c/infer/isfinite_infer.h" +#include "nnacl_c/infer/fse_decoder_infer.h" +#include "nnacl_c/infer/custom_gru_infer.h" + +InferShape g_infer_func[PrimType_MAX] = {0}; +InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; +void RegAllInferFunc1() { + g_infer_func[PrimType_NONE] = NULL; + g_infer_func[PrimType_Abs] = CommonInferShape; + g_infer_func[PrimType_AbsGrad] = CommonGradInferShape; + g_infer_func[PrimType_Activation] = CommonInferShape; + g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape; + g_infer_func[PrimType_Adam] = AdamInferShape; + g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape; + g_infer_func[PrimType_AdderFusion] = Conv2dInferShape; + g_infer_func[PrimType_AddFusion] = ArithmeticInferShape; + g_infer_func[PrimType_AddGrad] = AddSubGradInferShape; + g_infer_func[PrimType_AddN] = AddnInferShape; + g_infer_func[PrimType_Affine] = AffineInferShape; + g_infer_func[PrimType_All] = NULL; + g_infer_func[PrimType_AllGather] = AllGatherInferShape; + g_infer_func[PrimType_ApplyMomentum] = ApplyMomentumInferShape; + g_infer_func[PrimType_ArgMaxFusion] = ArgMinMaxInferShape; + g_infer_func[PrimType_ArgMinFusion] = ArgMinMaxInferShape; + g_infer_func[PrimType_Assert] = AssertOpInferShape; + g_infer_func[PrimType_Assign] = AssignInferShape; + g_infer_func[PrimType_AssignAdd] = AssignAddInferShape; + g_infer_func[PrimType_Attention] = AttentionInferShape; + g_infer_func[PrimType_AudioSpectrogram] = AudioSpectrogramInferShape; + g_infer_func[PrimType_AvgPoolFusion] = PoolingInferShape; + g_infer_func[PrimType_AvgPoolGrad] = PoolingGradInferShape; + g_infer_func[PrimType_BatchNorm] = CommonInferShape; + g_infer_func[PrimType_BatchNormGrad] = BnGradInferShape; + g_infer_func[PrimType_BatchToSpace] = BatchToSpaceInferShape; + g_infer_func[PrimType_BatchToSpaceND] = NULL; + g_infer_func[PrimType_BiasAdd] = ArithmeticInferShape; + g_infer_func[PrimType_BiasAddGrad] = BiasGradInferShape; + g_infer_func[PrimType_BinaryCrossEntropy] = BinaryCrossEntropyInferShape; + g_infer_func[PrimType_BinaryCrossEntropyGrad] = CommonInferShape; + g_infer_func[PrimType_BroadcastTo] = BroadcastToInferShape; + g_infer_func[PrimType_Call] = InvalidInferShape; + g_infer_func[PrimType_Cast] = CastInferShape; + g_infer_func[PrimType_Ceil] = CommonInferShape; + g_infer_func[PrimType_Clip] = CommonInferShape; + g_infer_func[PrimType_Concat] = ConcatInferShape; + g_infer_func[PrimType_ConstantOfShape] = ConstantOfShapeInferShape; + g_infer_func[PrimType_Conv2DBackpropFilterFusion] = 
Conv2dGradFilterInferShape; + g_infer_func[PrimType_Conv2DBackpropInputFusion] = Conv2dGradInputInferShape; + g_infer_func[PrimType_Conv2DFusion] = Conv2dInferShape; + g_infer_func[PrimType_Conv2dTransposeFusion] = Deconv2dInferShape; + g_infer_func[PrimType_Cos] = CommonInferShape; + g_infer_func[PrimType_Crop] = CropInferShape; + g_infer_func[PrimType_CropAndResize] = CropAndResizeInferShape; + g_infer_func[PrimType_CumSum] = CumsumInferShape; + g_infer_func[PrimType_Custom] = NULL; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_CustomExtractFeatures] = CustomExtractFeaturesInferShape; +#endif +} + +void RegAllInferFunc2() { +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_CustomNormalize] = CustomNormalizeInferShape; + g_infer_func[PrimType_CustomPredict] = CustomPredictInferShape; +#endif + g_infer_func[PrimType_DeConv2DGradFilter] = NULL; + g_infer_func[PrimType_Depend] = CommonInferShape; + g_infer_func[PrimType_DepthToSpace] = DepthToSpaceInferShape; + g_infer_func[PrimType_DetectionPostProcess] = DetectionPostProcessInferShape; + g_infer_func[PrimType_DivFusion] = ArithmeticInferShape; + g_infer_func[PrimType_DivGrad] = ArithmeticGradInferShape; + g_infer_func[PrimType_Dropout] = DropoutInferShape; + g_infer_func[PrimType_DropoutGrad] = DropoutGradInferShape; + g_infer_func[PrimType_DynamicQuant] = DynamicQuantInferShape; + g_infer_func[PrimType_Eltwise] = ArithmeticInferShape; + g_infer_func[PrimType_Elu] = CommonInferShape; + g_infer_func[PrimType_EmbeddingLookupFusion] = EmbeddingLookupInferShape; + g_infer_func[PrimType_Equal] = ArithmeticCompareInferShape; + g_infer_func[PrimType_Erf] = CommonInferShape; + g_infer_func[PrimType_ExpandDims] = ExpandDimsInferShape; + g_infer_func[PrimType_ExpFusion] = CommonInferShape; + g_infer_func[PrimType_FakeQuantWithMinMaxVars] = CommonInferShape; + g_infer_func[PrimType_FakeQuantWithMinMaxVarsPerChannel] = NULL; + g_infer_func[PrimType_FftImag] = FftImagInferShape; + g_infer_func[PrimType_FftReal] = FftRealInferShape; + g_infer_func[PrimType_Fill] = FillInferShape; + g_infer_func[PrimType_FillV2] = FillInferShape; + g_infer_func[PrimType_Flatten] = FlattenInferShape; + g_infer_func[PrimType_FlattenGrad] = FlattenGradInferShape; + g_infer_func[PrimType_Floor] = CommonInferShapeWithOneInput; + g_infer_func[PrimType_FloorDiv] = ArithmeticInferShape; + g_infer_func[PrimType_FloorMod] = ArithmeticInferShape; + g_infer_func[PrimType_FullConnection] = FullConnectionInferShape; + g_infer_func[PrimType_FusedBatchNorm] = FusedBatchNormInferShape; + g_infer_func[PrimType_Gather] = GatherInferShape; + g_infer_func[PrimType_GatherNd] = GatherNdInferShape; + g_infer_func[PrimType_GenOP] = NULL; + g_infer_func[PrimType_GLU] = GluInferShape; + g_infer_func[PrimType_Greater] = ArithmeticCompareInferShape; + g_infer_func[PrimType_GreaterEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_GRU] = GruInferShape; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_HashtableLookup] = HashtableLoopupInferShape; +#endif + g_infer_func[PrimType_InstanceNorm] = InstanceNormInferShape; + g_infer_func[PrimType_InvertPermutation] = InvertPermutationInferShape; + g_infer_func[PrimType_IsFinite] = IsFiniteInferShape; + g_infer_func[PrimType_L2NormalizeFusion] = CommonInferShape; + g_infer_func[PrimType_LayerNormFusion] = LayerNormInferShape; + g_infer_func[PrimType_LayerNormGrad] = LayerNormGradInferShape; + g_infer_func[PrimType_LeakyRelu] = CommonInferShape; + g_infer_func[PrimType_Less] = ArithmeticCompareInferShape; + 
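// All comparison ops share ArithmeticCompareInferShape, which broadcasts the two
+ // input shapes and forces the output dtype to bool.
+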
g_infer_func[PrimType_LessEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_LinSpace] = LinSpaceInferShape; +} + +void RegAllInferFunc3() { + g_infer_func[PrimType_Log] = CommonInferShape; + g_infer_func[PrimType_LogGrad] = CommonGradInferShape; + g_infer_func[PrimType_LogicalAnd] = ArithmeticInferShape; + g_infer_func[PrimType_LogicalNot] = CommonInferShape; + g_infer_func[PrimType_LogicalOr] = ArithmeticInferShape; + g_infer_func[PrimType_LogSoftmax] = LogSoftmaxInferShape; + g_infer_func[PrimType_LpNormalization] = NULL; + g_infer_func[PrimType_LRN] = CommonInferShapeWithNHWC; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_LshProjection] = LshProjectionInferShape; +#endif + g_infer_func[PrimType_LSTM] = LstmInferShape; + g_infer_func[PrimType_LSTMGrad] = LstmGradInferShape; + g_infer_func[PrimType_LSTMGradData] = LstmGradDataInferShape; + g_infer_func[PrimType_LSTMGradWeight] = LstmGradWeightInferShape; + g_infer_func[PrimType_MatMulFusion] = MatmulInferShape; + g_infer_func[PrimType_Maximum] = ArithmeticInferShape; + g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape; + g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape; + g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape; + g_infer_func[PrimType_SwitchLayer] = InvalidInferShape; + g_infer_func[PrimType_Mfcc] = MfccInferShape; + g_infer_func[PrimType_MIN] = NULL; + g_infer_func[PrimType_Minimum] = ArithmeticInferShape; + g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape; + g_infer_func[PrimType_Mod] = ArithmeticInferShape; + g_infer_func[PrimType_MulFusion] = ArithmeticInferShape; + g_infer_func[PrimType_MulGrad] = ArithmeticGradInferShape; + g_infer_func[PrimType_Neg] = CommonInferShape; + g_infer_func[PrimType_NegGrad] = CommonGradInferShape; + g_infer_func[PrimType_NLLLoss] = NLLLossInferShape; + g_infer_func[PrimType_NLLLossGrad] = NLLLossGradInferShape; + g_infer_func[PrimType_NonMaxSuppression] = NonMaxSuppressionInferShape; + g_infer_func[PrimType_NonZero] = NULL; + g_infer_func[PrimType_NotEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_OneHot] = OneHotInferShape; + g_infer_func[PrimType_OnesLike] = NULL; + g_infer_func[PrimType_PadFusion] = PadInferShape; + g_infer_func[PrimType_PartialFusion] = InvalidInferShape; + g_infer_func[PrimType_PowerGrad] = CommonGradInferShape; + g_infer_func[PrimType_PowFusion] = PowerInferShape; + g_infer_func[PrimType_PReLUFusion] = CommonInferShape; + g_infer_func[PrimType_PriorBox] = PriorBoxInferShape; + g_infer_func[PrimType_QuantDTypeCast] = QuantDtypeCastInferShape; + g_infer_func[PrimType_RaggedRange] = RaggedRangeInferShape; + g_infer_func[PrimType_RandomNormal] = RandomNormalInferShape; + g_infer_func[PrimType_RandomStandardNormal] = RandomStandardNormalInferShape; + g_infer_func[PrimType_Range] = RangeInferShape; + g_infer_func[PrimType_Rank] = RankInferShape; +} + +void RegAllInferFunc4() { + g_infer_func[PrimType_RealDiv] = ArithmeticInferShape; + g_infer_func[PrimType_Reciprocal] = CommonInferShape; + g_infer_func[PrimType_ReduceFusion] = ReduceInferShape; + g_infer_func[PrimType_ReduceScatter] = ReduceScatterInferShape; + g_infer_func[PrimType_Reshape] = ReshapeInferShape; + g_infer_func[PrimType_Resize] = ResizeInferShape; + g_infer_func[PrimType_ResizeGrad] = ResizeGradInferShape; + g_infer_func[PrimType_ReverseSequence] = CommonInferShape; + g_infer_func[PrimType_ReverseV2] = CommonInferShape; + g_infer_func[PrimType_Rfft] = RfftInferShape; + g_infer_func[PrimType_ROIPooling] = ROIPoolingInferShape; + 
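// CommonInferShape mirrors the first input (same shape, dtype, and format) and is
+ // shared by the shape-preserving elementwise ops registered below.
+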
g_infer_func[PrimType_Round] = CommonInferShape; + g_infer_func[PrimType_Rsqrt] = CommonInferShape; + g_infer_func[PrimType_RsqrtGrad] = NULL; + g_infer_func[PrimType_ScaleFusion] = CommonInferShape; + g_infer_func[PrimType_ScatterNd] = ScatterNdInferShape; + g_infer_func[PrimType_ScatterNdUpdate] = ScatterNdUpdateInferShape; + g_infer_func[PrimType_TensorScatterAdd] = ScatterNdUpdateInferShape; + g_infer_func[PrimType_Select] = SelectInferShape; + g_infer_func[PrimType_SGD] = SgdInferShape; + g_infer_func[PrimType_Shape] = ShapeInferShape; + g_infer_func[PrimType_SigmoidCrossEntropyWithLogits] = CommonInferShape; + g_infer_func[PrimType_SigmoidCrossEntropyWithLogitsGrad] = CommonInferShape; + g_infer_func[PrimType_Sin] = CommonInferShape; + g_infer_func[PrimType_Size] = SizeInferShape; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_SkipGram] = SkipGramInferShape; +#endif + g_infer_func[PrimType_SliceFusion] = SliceInferShape; + g_infer_func[PrimType_SmoothL1Loss] = CommonInferShape; + g_infer_func[PrimType_SmoothL1LossGrad] = CommonInferShape; + g_infer_func[PrimType_Softmax] = SoftMaxInferShape; + g_infer_func[PrimType_SoftmaxCrossEntropyWithLogits] = SoftmaxCrossEntropyInferShape; + g_infer_func[PrimType_SpaceToBatch] = SpaceToBatchInferShape; + g_infer_func[PrimType_SpaceToBatchND] = SpaceToBatchNdInferShape; + g_infer_func[PrimType_SpaceToDepth] = SpaceToDepthInferShape; + g_infer_func[PrimType_SparseSoftmaxCrossEntropyWithLogits] = SparseSoftmaxCrossEntropyWithLogitsInferShape; + g_infer_func[PrimType_SparseToDense] = SparseToDenseInferShape; + g_infer_func[PrimType_Splice] = SpliceInferShape; + g_infer_func[PrimType_Split] = SplitInferShape; + g_infer_func[PrimType_SplitWithOverlap] = SplitWithOverlapInferShape; + g_infer_func[PrimType_Sqrt] = CommonInferShape; + g_infer_func[PrimType_SqrtGrad] = NULL; + g_infer_func[PrimType_Square] = CommonInferShape; + g_infer_func[PrimType_SquaredDifference] = ArithmeticInferShape; + g_infer_func[PrimType_Squeeze] = SqueezeInferShape; + g_infer_func[PrimType_Stack] = StackInferShape; + g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape; + g_infer_func[PrimType_StridedSliceGrad] = StridedSliceGradInferShape; + g_infer_func[PrimType_SubFusion] = ArithmeticInferShape; + g_infer_func[PrimType_SubGrad] = AddSubGradInferShape; +} + +void RegAllInferFunc5() { + g_infer_func[PrimType_Switch] = InvalidInferShape; +#ifdef MSLITE_ENABLE_CONTROLFLOW + g_infer_func[PrimType_TensorArray] = TensorArrayInferShape; + g_infer_func[PrimType_TensorArrayRead] = TensorArrayReadInferShape; + g_infer_func[PrimType_TensorArrayWrite] = TensorArrayWriteInferShape; + g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape; + g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape; + g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape; + g_infer_func[PrimType_TensorListSetItem] = TensorListSetItemInferShape; + g_infer_func[PrimType_TensorListStack] = TensorListStackInferShape; +#endif + g_infer_func[PrimType_TileFusion] = TileInferShape; + g_infer_func[PrimType_TopKFusion] = TopKInferShape; + g_infer_func[PrimType_Transpose] = TransposeInferShape; + g_infer_func[PrimType_UniformReal] = UniformRealInferShape; + g_infer_func[PrimType_Unique] = UniqueInferShape; + g_infer_func[PrimType_UnsortedSegmentSum] = UnsortedSegmentSumInferShape; + g_infer_func[PrimType_Unsqueeze] = UnsqueezeInferShape; + g_infer_func[PrimType_Unstack] = UnstackInferShape; + g_infer_func[PrimType_Where] = WhereInferShape; + 
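// ZerosLike keeps its input's shape and dtype, so the shared CommonInferShape entry suffices.
+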
g_infer_func[PrimType_ZerosLike] = CommonInferShape; + + // fused operators. + g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL; + g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL; +#ifndef RUNTIME_PASS_CLIP + g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape; + g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape; + g_inner_op_infer_func[PrimType_Inner_DecoderLayer - PrimType_InnerOpMin] = DecoderLayerInferShape; + g_inner_op_infer_func[PrimType_Inner_FseDecode - PrimType_InnerOpMin] = FseDecoderInferShape; +#endif + g_inner_op_infer_func[PrimType_Inner_CustomGru - PrimType_InnerOpMin] = CustomGruInferShape; + g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL; +} + +#else +__attribute__((init_priority(101))) InferShape g_infer_func[PrimType_MAX] = {0}; +__attribute__((init_priority(101))) InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; +#endif // _MSC_VER + +InferShape GetInferFunc(int prim_type) { +#ifdef _MSC_VER + if (g_infer_func[PrimType_Abs] == NULL) { + RegAllInferFunc1(); + RegAllInferFunc2(); + RegAllInferFunc3(); + RegAllInferFunc4(); + RegAllInferFunc5(); + } +#endif + if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { + return g_infer_func[prim_type]; + } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { + return g_inner_op_infer_func[prim_type - PrimType_InnerOpMin]; + } + return NULL; +} + +void RegInfer(int prim_type, InferShape func) { + if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { + g_infer_func[prim_type] = func; + } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { + g_inner_op_infer_func[prim_type - PrimType_InnerOpMin] = func; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h new file mode 100644 index 0000000000000000000000000000000000000000..4a43a24fdebc4b2ed860665ec76abb886b3379ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_INFER_INFER_REGISTER_H_
+#define MINDSPORE_NNACL_INFER_INFER_REGISTER_H_
+
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/infer/infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void RegInfer(int prim_type, InferShape func);
+
+#ifdef _MSC_VER
+#define REG_INFER(op, type, func)
+#else
+#define REG_INFER(op, type, func) \
+  __attribute__((constructor(102))) void Reg##op##Infer() { RegInfer(type, func); }
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_INFER_INFER_REGISTER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c
new file mode 100644
index 0000000000000000000000000000000000000000..41887cab995df8204b8ea5e28703753901b0bcc7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/instance_norm_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter) {
+  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  TensorC *output = outputs[0];
+  SetDataTypeFormat(output, inputs[0]);
+  if (output->format_ == Format_NC4HW4) {
+    output->format_ = Format_NHWC;
+  }
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  SetShapeTensor(output, inputs[0]);
+  if (inputs[0]->format_ != Format_NC4HW4) {
+    return NNACL_OK;
+  }
+  if (output->shape_size_ <= DIMENSION_2D) {
+    return NNACL_OK;
+  }
+  int channel = output->shape_[1];
+  ShapeErase(output->shape_, &output->shape_size_, 1);
+  ShapePush(output->shape_, &output->shape_size_, channel);
+  return NNACL_OK;
+}
+REG_INFER(InstanceNorm, PrimType_InstanceNorm, InstanceNormInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..cc90bad4134f98a86dc299c7f578db538a6a0f3d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INSTANCE_NORM_INFER_H +#define MINDSPORE_NNACL_INSTANCE_NORM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INSTANCE_NORM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..11be029d668ab4a768dbbc49cd469809fb443276 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/invalid_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int InvalidInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + return NNACL_INFER_INVALID; +} + +REG_INFER(PartialFusion, PrimType_PartialFusion, InvalidInferShape) +REG_INFER(Switch, PrimType_Switch, InvalidInferShape) +REG_INFER(Call, PrimType_Call, InvalidInferShape) +REG_INFER(SwitchLayer, PrimType_SwitchLayer, InvalidInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e9abbbbbe1c37e733e22b81e7f7f74f4bf1bdf99 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INVALID_INFER_H +#define MINDSPORE_NNACL_INVALID_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InvalidInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INVALID_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..db56526b5de1f65d9c45327f9c81a037e41abe11 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/invert_permutation_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int InvertPermutationInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + if (input->shape_size_ != 1) { + return NNACL_ERR; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(InvertPermutation, PrimType_InvertPermutation, InvertPermutationInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8f5a8074db2443298a7a6681a04d28c247ff4afc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H +#define MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InvertPermutationInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..9c11207a14910c92310170147b1820557dba7a10 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/isfinite_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int IsFiniteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = kNumberTypeBool; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < input->shape_size_; i++) { + output->shape_[i] = input->shape_[i]; + } + output->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(IsFinite, PrimType_IsFinite, IsFiniteInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..46b6802d773fa0a258910801f9f55423668147d5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ISFINITE_INFER_H_ +#define MINDSPORE_NNACL_ISFINITE_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int IsFiniteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ISFINITE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c3ea45ce37eec1fc81f669c47c04f4961e488731 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/layer_norm_grad_infer.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/layernormgrad_parameter.h" +#include "nnacl_c/infer/infer_register.h" + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + LayerNormGradParameter *param = (LayerNormGradParameter *)parameter; + const TensorC *input_x = inputs[0]; + TensorC *output_dx = outputs[0]; + TensorC *output_dg = outputs[1]; + TensorC *output_db = outputs[2]; + SetDataTypeFormat(output_dx, input_x); + SetDataTypeFormat(output_dg, input_x); + SetDataTypeFormat(output_db, input_x); + SetShapeTensor(output_dx, input_x); + int begin_params_axis = param->begin_params_axis_; + if (param->begin_params_axis_ < 0) { + begin_params_axis += (int)(input_x->shape_size_); + } + size_t size = 0; + if (input_x->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + for (int i = begin_params_axis; i < input_x->shape_size_; i++) { + if (size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + output_dg->shape_[size] = input_x->shape_[i]; + output_db->shape_[size] = input_x->shape_[i]; + size++; + } + output_db->shape_size_ = size; + output_dg->shape_size_ = size; + return NNACL_OK; +} + +REG_INFER(LayerNormGrad, PrimType_LayerNormGrad, LayerNormGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..dc884dc590a2b7351a13f4aa6519c4099e3faf22 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
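LayerNormGradInferShape gives dx the full shape of x, while dgamma and dbeta take only the dims from begin_params_axis onward, with negative axes wrapping. A minimal standalone check of that slice arithmetic (the helper name is illustrative, not part of nnacl):

    #include <assert.h>
    #include <stddef.h>

    /* Collect x's dims from begin_params_axis to the end, wrapping negatives. */
    static size_t ParamShape(const int *x_shape, size_t x_dims, int begin_params_axis, int *out) {
      if (begin_params_axis < 0) {
        begin_params_axis += (int)x_dims;
      }
      size_t n = 0;
      for (int i = begin_params_axis; i < (int)x_dims; ++i) {
        out[n++] = x_shape[i];
      }
      return n;
    }

    int main(void) {
      int x[3] = {2, 3, 4};
      int dg[3];
      size_t n = ParamShape(x, 3, -1, dg);   /* begin_params_axis = -1 -> axis 2 */
      assert(n == 1 && dg[0] == 4);          /* dgamma/dbeta shape: [4] */
      return 0;
    }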
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..24a9021c6a00344c837feb76f3a9885108d3aea6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/layer_norm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if ((inputs_size != 1 && inputs_size != 3) || (outputs_size != 1 && outputs_size != 3)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + LayerNormParameter *param = (LayerNormParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (input->shape_size_ > COMM_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->begin_params_axis_ < (-1 * (int)(input->shape_size_)) || + param->begin_params_axis_ >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + param->begin_norm_axis_ = + param->begin_norm_axis_ < 0 ? 
param->begin_norm_axis_ + ((int)(input->shape_size_)) : param->begin_norm_axis_; + SetShapeTensor(output, input); + // take care of other outputs + if (outputs_size == 3) { + TensorC *output_mean = outputs[1]; + TensorC *output_var = outputs[2]; + SetDataTypeFormat(output_mean, input); + SetDataTypeFormat(output_var, input); + int size = 0; + NNACL_CHECK_TRUE_RET(param->begin_norm_axis_ <= MAX_SHAPE_SIZE, NNACL_ERR); + for (; size < param->begin_norm_axis_; size++) { + output_mean->shape_[size] = input->shape_[size]; + output_var->shape_[size] = input->shape_[size]; + } + output_mean->shape_size_ = (size_t)size; + output_var->shape_size_ = (size_t)size; + } + + return NNACL_OK; +} + +REG_INFER(LayerNormFusion, PrimType_LayerNormFusion, LayerNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..85d51d2b0c8b6c81dd111844f99f46fa08934290 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LAYER_NORM_INFER_H +#define MINDSPORE_NNACL_LAYER_NORM_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/layer_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LAYER_NORM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d774f9f9d6d6308bd9aaf8b8b61e093522a8189d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
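When LayerNormFusion has three outputs, the mean and variance tensors keep only the leading dims before begin_norm_axis. A standalone check of that rule with example values:

    #include <assert.h>

    /* Mean/variance keep input dims [0, begin_norm_axis); negatives wrap. */
    int main(void) {
      int input[4] = {8, 16, 16, 32};   /* e.g. NHWC */
      int dims = 4;
      int begin_norm_axis = -1;         /* normalize over the last axis */
      if (begin_norm_axis < 0) {
        begin_norm_axis += dims;        /* -> 3 */
      }
      int mean_shape[4];
      int mean_dims = 0;
      for (; mean_dims < begin_norm_axis; ++mean_dims) {
        mean_shape[mean_dims] = input[mean_dims];
      }
      assert(mean_dims == 3 && mean_shape[0] == 8 && mean_shape[2] == 16);
      return 0;
    }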
+ */ + +#include "nnacl_c/infer/lin_space_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int LinSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[2]) < 1) { + return NNACL_ERR; + } + int *num = (int *)(inputs[2]->data_); + if (num == NULL) { + return NNACL_INFER_INVALID; + } + output->shape_size_ = 1; + output->shape_[0] = num[0]; + return NNACL_OK; +} + +REG_INFER(LinSpace, PrimType_LinSpace, LinSpaceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1f5cf3faafb8d07814c3da1d98991c288fb50e48 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LIN_SPACE_INFER_H +#define MINDSPORE_NNACL_LIN_SPACE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LinSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LIN_SPACE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3d2ea4cd27b58f8cd0d4699d7759dc2d3b93f60d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/log_softmax_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const int input_size_limit = 1; + const int output_size_limit = 1; + if (inputs_size != input_size_limit || outputs_size != output_size_limit) { + return NNACL_ERR; + } + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > 5) { + return NNACL_ERR; + } + SetShapeTensor(output, input); + SoftmaxParameter *param = (SoftmaxParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +REG_INFER(LogSoftmax, PrimType_LogSoftmax, LogSoftmaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d320fb58ce40bd7aad082cb7e823801b0f58c3e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H +#define MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a70323d261fbe10584e47652b10b9b4a7a200e5a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/lstm_grad_data_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/lstm_fp32.h" +#include "nnacl_c/fp32_grad/lstm_grad_fp32.h" + +int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 9, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + LstmGradParameter *p = (LstmGradParameter *)parameter; + const TensorC *Y = inputs[SECOND_INPUT]; + const TensorC *H = inputs[THIRD_INPUT]; + const TensorC *C = inputs[FOURTH_INPUT]; + const TensorC *weight = inputs[FIFTH_INPUT]; + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + + for (int i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], Y); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (Y->shape_size_ != C3NUM || weight->shape_size_ != C3NUM) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]); + ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]); + ShapePush(out_shape, &out_shape_size, p->input_size_); + + SetShapeArray(outputs[FIRST_INPUT], out_shape, C3NUM); + SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_); + SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_); + + return NNACL_OK; +} + +REG_INFER(LSTMGradData, PrimType_LSTMGradData, LstmGradDataInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e3a4885d395a2ac9ea17f3c09b565720300626f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H +#define MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..124b80b7e7207dffbdd2415bd95f477be356b902 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/lstm_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/lstm_grad_fp32.h" + +int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 11, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *H = inputs[1]; + const TensorC *C = inputs[2]; + const TensorC *weight = inputs[3]; + TensorC *output = outputs[0]; + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != 3 || weight->shape_size_ != 3) { + return NNACL_ERR; + } + + SetShapeArray(output, input->shape_, input->shape_size_); + SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_); + SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_); + SetShapeArray(outputs[FOURTH_INPUT], weight->shape_, weight->shape_size_); + + return NNACL_OK; +} + +REG_INFER(LSTMGrad, PrimType_LSTMGrad, LstmGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5044ceca892748d5803664c882ac4e624e33d8b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_GRAD_INFER_H +#define MINDSPORE_NNACL_LSTM_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0e8dfc701f06d1fb24b9eaf9b8a19d599b55aa43 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/lstm_grad_weight_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/lstm_grad_fp32.h" + +int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + const TensorC *H = inputs[SECOND_INPUT]; + const TensorC *Y = inputs[THIRD_INPUT]; + + TensorC *output = outputs[FIRST_INPUT]; + for (int i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != C3NUM || H->shape_size_ != C3NUM || Y->shape_size_ != C3NUM) { + return NNACL_ERR; + } + LstmGradParameter *param = (LstmGradParameter *)parameter; + int has_bias = param->has_bias_; + int output_shape[3] = {0, 1, 1}; + int gate_size = 4 * param->hidden_size_; + output_shape[0] += gate_size * param->input_size_; + output_shape[0] += gate_size * param->hidden_size_; + if (has_bias) { + output_shape[0] += C2NUM * gate_size; + } + int dir_mul = (param->bidirectional_) ? 
C2NUM : C1NUM; + output_shape[0] *= dir_mul; + SetShapeArray(output, output_shape, C3NUM); + + return NNACL_OK; +} + +REG_INFER(LSTMGradWeight, PrimType_LSTMGradWeight, LstmGradWeightInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d0ffa18f6a99f708cfed717fba9189f2b798b48a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H +#define MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..139eedfac6aac760b60e7beb513c3e396a57e540 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c @@ -0,0 +1,161 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
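The flattened weight size computed by LstmGradWeightInferShape is gate_size * input_size + gate_size * hidden_size per direction, plus two bias vectors when present, with gate_size = 4 * hidden_size. A standalone check of that arithmetic:

    #include <assert.h>

    int main(void) {
      int I = 16, H = 8, has_bias = 1, bidirectional = 0;
      int gate_size = 4 * H;                  /* i, f, g, o gates */
      int total = gate_size * I + gate_size * H;
      if (has_bias) {
        total += 2 * gate_size;
      }
      total *= bidirectional ? 2 : 1;
      assert(total == 512 + 256 + 64);        /* output shape: [832, 1, 1] */
      return 0;
    }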
+ */ + +#include "nnacl_c/infer/lstm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +static const int no_of_recorde_values = 5; + +int CheckInputShapeValid(const TensorC *const *inputs, size_t inputs_size, const LstmParameter *parameter) { + if (inputs_size < C6NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *input = inputs[FIRST_INPUT]; + const TensorC *weight_i = inputs[SECOND_INPUT]; + const TensorC *weight_g = inputs[THIRD_INPUT]; + const TensorC *bias = inputs[FOURTH_INPUT]; + const TensorC *hidden_init = inputs[FIFTH_INPUT]; + const TensorC *cell_init = inputs[SIXTH_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(weight_i); + NNACL_CHECK_NULL_RETURN_ERR(weight_g); + NNACL_CHECK_NULL_RETURN_ERR(bias); + NNACL_CHECK_NULL_RETURN_ERR(hidden_init); + NNACL_CHECK_NULL_RETURN_ERR(cell_init); + NNACL_CHECK_TRUE_RET(input->shape_size_ == DIMENSION_3D && weight_i->shape_size_ == DIMENSION_3D && + weight_g->shape_size_ == DIMENSION_3D && bias->shape_size_ == DIMENSION_2D, + NNACL_ERR); + int batch = input->shape_[kNHWC_H]; + int input_size = input->shape_[kNHWC_W]; + int hidden_size = weight_i->shape_[kNHWC_H] / C4NUM; + int out_size = hidden_size; + if (inputs_size == C7NUM) { + NNACL_CHECK_TRUE_RET(inputs[SEVENTH_INPUT]->shape_size_ == DIMENSION_3D, NNACL_INPUT_TENSOR_ERROR); + out_size = inputs[SEVENTH_INPUT]->shape_[kNHWC_H]; + } + bool bidirectional = parameter->bidirectional_; + int bidirection = bidirectional ? C2NUM : C1NUM; + NNACL_CHECK_TRUE_RET(weight_i->shape_[kNHWC_N] == bidirection && weight_i->shape_[kNHWC_H] == hidden_size * C4NUM && + weight_i->shape_[kNHWC_W] == input_size, + NNACL_ERR); + NNACL_CHECK_TRUE_RET(weight_g->shape_[kNHWC_N] == bidirection && weight_g->shape_[kNHWC_H] == hidden_size * C4NUM && + weight_g->shape_[kNHWC_W] == out_size, + NNACL_ERR); + NNACL_CHECK_TRUE_RET(bias->shape_[kNHWC_N] == bidirection && bias->shape_[kNHWC_H] == hidden_size * C8NUM, NNACL_ERR); + if (!bidirectional && hidden_init->shape_size_ == DIMENSION_2D) { + NNACL_CHECK_TRUE_RET(hidden_init->shape_[kNHWC_N] == batch && hidden_init->shape_[kNHWC_H] == out_size, NNACL_ERR); + } else { + NNACL_CHECK_TRUE_RET(hidden_init->shape_size_ == DIMENSION_3D && hidden_init->shape_[kNHWC_N] == bidirection && + hidden_init->shape_[kNHWC_H] == batch && hidden_init->shape_[kNHWC_W] == out_size, + NNACL_ERR); + } + if (!bidirectional && cell_init->shape_size_ == DIMENSION_2D) { + NNACL_CHECK_TRUE_RET(cell_init->shape_[kNHWC_N] == batch && cell_init->shape_[kNHWC_H] == hidden_size, NNACL_ERR); + } else { + NNACL_CHECK_TRUE_RET(cell_init->shape_size_ == DIMENSION_3D && cell_init->shape_[kNHWC_N] == bidirection && + cell_init->shape_[kNHWC_H] == batch && cell_init->shape_[kNHWC_W] == hidden_size, + NNACL_ERR); + } + return NNACL_OK; +} + +int InferFirstOutputMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { + for (size_t i = 0; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ != C3NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); + int out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; + output->shape_[THIRD_INPUT] = (param->bidirectional_ ? 
C2NUM : 1) * out_size; + return NNACL_OK; +} + +int InferFirstOutputNonMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { + if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) { + return NNACL_ERR; + } + ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); + const TensorC *hidden_init = inputs[FIFTH_INPUT]; + int out_size = hidden_init->shape_[hidden_init->shape_size_ - 1]; + output->shape_[THIRD_INPUT] = out_size; + int direction = param->bidirectional_ ? C2NUM : C1NUM; + int ret = ShapeInsert(output->shape_, &output->shape_size_, 1, direction); + return ret; +} + +int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + for (int i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + LstmParameter *param = (LstmParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int hidden_size = 0; + int out_size = 0; + if (inputs_size == C4NUM) { + int ret = InferFirstOutputMindir(inputs, inputs_size, output, param); + if (ret != NNACL_OK) { + return ret; + } + hidden_size = inputs[THIRD_INPUT]->shape_[THIRD_INPUT]; + out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; + } else { + int ret = InferFirstOutputNonMindir(inputs, inputs_size, output, param); + if (ret != NNACL_OK) { + return ret; + } + hidden_size = inputs[SIXTH_INPUT]->shape_[inputs[SIXTH_INPUT]->shape_size_ - 1]; + out_size = inputs[FIFTH_INPUT]->shape_[inputs[FIFTH_INPUT]->shape_size_ - 1]; + } + + int dir_multiplier = param->bidirectional_ ? C2NUM : C1NUM; + int state_shape[MAX_SHAPE_SIZE]; + size_t state_shape_size = 0; + + ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_); + state_shape[FIRST_INPUT] = dir_multiplier; + state_shape[THIRD_INPUT] = out_size; + SetShapeArray(outputs[SECOND_INPUT], state_shape, state_shape_size); + state_shape[THIRD_INPUT] = hidden_size; + SetShapeArray(outputs[THIRD_INPUT], state_shape, state_shape_size); + + if (outputs_size > DIMENSION_4D) { + int intermediate_states_shape[MAX_SHAPE_SIZE]; + const size_t intermediate_states_shape_size = 1; + int batch_size = input->shape_[SECOND_INPUT]; + int seq_len = input->shape_[FIRST_INPUT]; + intermediate_states_shape[FIRST_INPUT] = + batch_size * seq_len * dir_multiplier * (out_size + no_of_recorde_values * hidden_size); + SetShapeArray(outputs[FOURTH_INPUT], intermediate_states_shape, intermediate_states_shape_size); + SetShapeArray(outputs[FIFTH_INPUT], state_shape, state_shape_size); + } + + return NNACL_OK; +} + +REG_INFER(LSTM, PrimType_LSTM, LstmInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..20392e1f926610809df4e45a7169a19bf330d4ba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
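In the non-MINDIR path, LstmInferShape inserts the direction dim at axis 1 of the first output and shapes both state outputs as [directions, batch, size]. A standalone check with example dimensions:

    #include <assert.h>

    /* x: [seq_len, batch, input_size]; D = bidirectional ? 2 : 1;
     * output: [seq_len, D, batch, out_size];
     * states: [D, batch, out_size] and [D, batch, hidden_size]. */
    int main(void) {
      int seq_len = 10, batch = 4, out_size = 32, hidden_size = 32;
      int D = 2;                                      /* bidirectional */
      int output[4] = {seq_len, D, batch, out_size};  /* direction inserted at axis 1 */
      int h_state[3] = {D, batch, out_size};
      int c_state[3] = {D, batch, hidden_size};
      assert(output[1] == 2 && h_state[0] == 2 && c_state[2] == 32);
      return 0;
    }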
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_INFER_H +#define MINDSPORE_NNACL_LSTM_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a8fd494f908feddb314bc9123c6d78ab6278254a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c @@ -0,0 +1,148 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/matmul_infer.h"
+#include <stdlib.h>
+#include "nnacl_c/infer/infer_register.h"
+
+#define MIN_SHAPE_SIZE 2
+
+int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, const int *bias_shape,
+                          size_t bias_shape_size, const MatMulParameter *param) {
+  if (a_shape_size < MIN_SHAPE_SIZE || b_shape_size < MIN_SHAPE_SIZE) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
+    int min_value = MSMIN(a_shape[i], b_shape[i]);
+    int max_value = MSMAX(a_shape[i], b_shape[i]);
+    if (min_value != 0 && max_value % min_value != 0) {
+      return NNACL_INPUT_TENSOR_ERROR;
+    }
+  }
+  if (param->a_transpose_) {
+    iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - DIMENSION_2D]);
+  }
+  if (param->b_transpose_) {
+    iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]);
+  }
+  if (bias_shape_size == DIMENSION_1D && bias_shape[0] != b_shape[b_shape_size - 1]) {
+    return NNACL_ERR;
+  }
+  if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) {
+    return NNACL_ERR;
+  }
+  return NNACL_OK;
+}
+
+int CheckMatMulBias(int *shape, size_t dim_size) {
+  if (dim_size > 1) {
+    for (size_t i = 0; i < dim_size - 1; i++) {
+      if (shape[i] != DIMENSION_1D) {
+        return NNACL_ERR;
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+             OpParameter *parameter) {
+  TensorC *input0 = (TensorC *)inputs[0];
+  TensorC *input1 = (TensorC *)inputs[1];
+  TensorC *output = outputs[0];
+  MatMulParameter *param = (MatMulParameter *)parameter;
+  int a_shape[MAX_SHAPE_SIZE] = {0};
+  size_t a_shape_size = 0;
+  ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_);
+  int b_shape[MAX_SHAPE_SIZE] = {0};
+  size_t b_shape_size = 0;
+  ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_);
+  int *shape_align = a_shape_size > b_shape_size ? b_shape : a_shape;
+  size_t *shape_size_align = a_shape_size > b_shape_size ? &b_shape_size : &a_shape_size;
+  int diff = abs((int)a_shape_size - (int)b_shape_size);
+  for (int i = 0; i < diff; ++i) {
+    ShapeInsert(shape_align, shape_size_align, 0, 1);
+  }
+  int bias_shape[MAX_AXIS_SIZE] = {0};
+  size_t bias_shape_size = 0;
+  if (inputs_size == kInputSize2) {
+    TensorC *bias = (TensorC *)inputs[2];
+    ShapeSet(bias_shape, &bias_shape_size, bias->shape_, bias->shape_size_);
+    NNACL_CHECK_TRUE_RET(CheckMatMulBias(bias_shape, bias_shape_size) == NNACL_OK, NNACL_ERR);
+  }
+
+  bool del_start = false;
+  bool del_end = false;
+  if (a_shape_size == 1) {
+    int insert_ret = ShapeInsert(a_shape, &a_shape_size, 0, 1);
+    if (insert_ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+    del_start = true;
+  }
+  if (b_shape_size == 1) {
+    ShapePush(b_shape, &b_shape_size, 1);
+    del_end = true;
+  }
+  int ret = CheckMatmulInputShape(a_shape, a_shape_size, b_shape, b_shape_size, bias_shape, bias_shape_size, param);
+  if (ret != NNACL_OK) {
+    return NNACL_ERR;
+  }
+  int c_shape[MAX_SHAPE_SIZE];
+  size_t c_shape_size = 0;
+  ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size);
+  c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1];
+  if (del_start) {
+    int erase_ret = ShapeErase(c_shape, &c_shape_size, 0);
+    if (erase_ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+  }
+  if (del_end) {
+    c_shape_size--;
+  }
+
+  for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
+    c_shape[i] = MSMAX(a_shape[i], b_shape[i]);
+  }
+
+  SetShapeArray(output, c_shape, c_shape_size);
+  return NNACL_OK;
+}
+
+int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  TensorC *input0 = (TensorC *)inputs[0];
+  TensorC *input1 = (TensorC *)inputs[1];
+  TensorC *output = outputs[0];
+
+  TensorC *input = input1->data_ == NULL ? input1 : input0;  // take dtype/format from the input produced by another node (data_ == NULL).
+  SetDataTypeFormat(output, input);
+  if (input->data_type_ == kNumberTypeInt8 && parameter->quant_type_ == Quant_QuantDynamic) {
+    output->data_type_ = kNumberTypeFloat32;
+  }
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  return SetShape(inputs, inputs_size, outputs, outputs_size, parameter);
+}
+
+REG_INFER(MatMul, PrimType_MatMulFusion, MatmulInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h
new file mode 100644
index 0000000000000000000000000000000000000000..f4d513294376b79b78b5409edfe04e69cf28baf1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
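SetShape implements NumPy-style batched matmul shape rules: ranks are aligned by left-padding the lower-rank operand with 1s, transposes swap the trailing two dims, the inner dims must agree, and batch dims broadcast to the larger value. A standalone check of the rule:

    #include <assert.h>

    /* Example: A = [2, 1, 3, 4] x B = [5, 4, 6] -> C = [2, 5, 3, 6]. */
    int main(void) {
      int a[4] = {2, 1, 3, 4};
      int b[4] = {1, 5, 4, 6};                /* B left-padded to rank 4 */
      assert(a[3] == b[2]);                   /* inner dims agree: k = 4 */
      int c[4];
      for (int i = 0; i < 2; ++i) {
        c[i] = a[i] > b[i] ? a[i] : b[i];     /* batch broadcast */
      }
      c[2] = a[2];                            /* m rows of A */
      c[3] = b[3];                            /* n cols of B */
      assert(c[0] == 2 && c[1] == 5 && c[2] == 3 && c[3] == 6);
      return 0;
    }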
+ */ +#ifndef MINDSPORE_NNACL_MATMUL_INFER_H +#define MINDSPORE_NNACL_MATMUL_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_MATMUL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cbeb70a52d2f92db6929dbfc3ddf1ca60a21c14d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/max_min_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/infer/infer_register.h" + +int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x1 = inputs[0]; + const TensorC *x2 = inputs[1]; + const TensorC *dy = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (x1->shape_size_ > MAX_SHAPE_SIZE || x2->shape_size_ > MAX_SHAPE_SIZE || dy->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = (int)(param->ndim_); + param->in_elements_num1_ = (int)(param->ndim_); + param->out_elements_num_ = (int)(param->ndim_); + int fillDimNum0 = (int)(dy->shape_size_ - x1->shape_size_); + int fillDimNum1 = (int)(dy->shape_size_ - x2->shape_size_); + int j0 = 0; + int j1 = 0; + for (unsigned int i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = ((int)i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = ((int)i < fillDimNum1) ? 
1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); + return NNACL_OK; +} + +REG_INFER(MaximumGrad, PrimType_MaximumGrad, MaxMinGradInferShape) +REG_INFER(MinimumGrad, PrimType_MinimumGrad, MaxMinGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b927f5a53b25ef7ea430cb230db2cf0e39dccd09 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1d9f27a7fc6eaef5d7f31d707a51791e353cf9e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
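MaxMinGradInferShape right-aligns each input shape against dy's rank and pads the missing leading dims with 1 (the fillDimNum0/fillDimNum1 logic above), so elementwise gradients broadcast the same way the forward op did. A standalone check:

    #include <assert.h>

    int main(void) {
      int dy_rank = 3;                        /* dy: [2, 3, 4] */
      int x1[1] = {4};                        /* rank 1 -> fillDimNum0 = 2 */
      int fill = dy_rank - 1;
      int in0[3], j = 0;
      for (int i = 0; i < dy_rank; ++i) {
        in0[i] = (i < fill) ? 1 : x1[j++];    /* pad leading dims with 1 */
      }
      assert(in0[0] == 1 && in0[1] == 1 && in0[2] == 4);
      return 0;
    }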
+ */ + +#include "nnacl_c/infer/mfcc_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int MfccInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 3) { + return NNACL_ERR; + } + if (NNACLGetElementNum(inputs[1]) != 1) { + return NNACL_ERR; + } + output->shape_size_ = 3; + output->shape_[0] = input->shape_[0]; + output->shape_[1] = input->shape_[1]; + MfccParameter *param = (MfccParameter *)parameter; + output->shape_[2] = param->dct_coeff_num_; + return NNACL_OK; +} + +REG_INFER(Mfcc, PrimType_Mfcc, MfccInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c2b02349e476928c9d4933d9f12bf4ff6508426f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_MFCC_INFER_H +#define MINDSPORE_NNACL_MFCC_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct MfccParameter { + OpParameter op_parameter_; + int dct_coeff_num_; +} MfccParameter; + +int MfccInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_MFCC_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..8ab148c922b0402b6db95d5d1f0fd2f3c2135c16 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/nllloss_grad_infer.h" + +#include "nnacl_c/infer/infer_register.h" + +int NLLLossGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C5NUM, C1NUM); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *logits = inputs[0]; + const TensorC *loss_grad = inputs[1]; + const TensorC *labels = inputs[2]; + const TensorC *weight = inputs[3]; + const TensorC *total_weight = inputs[4]; + if (logits->shape_size_ != C2NUM || labels->shape_size_ != C1NUM || weight->shape_size_ != C1NUM || + total_weight->shape_size_ != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (labels->shape_[0] != logits->shape_[0] || weight->shape_[0] != logits->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + + NLLLossParameter *param = (NLLLossParameter *)parameter; + if (param->reduction_type_ == Reduction_None && loss_grad->shape_size_ != C1NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->reduction_type_ != Reduction_None && loss_grad->shape_size_ != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorC *logits_grad = outputs[0]; + SetDataTypeFormat(logits_grad, logits); + SetShapeTensor(logits_grad, logits); + return NNACL_OK; +} + +REG_INFER(NLLLossGrad, PrimType_NLLLossGrad, NLLLossGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9fcb5f9ddc1bf7de802f398c223bffdc66a458c0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H +#define MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/nllloss_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NLLLossGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ac1f1411e279e5f479748a1ccdf1facfaea143fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/nllloss_infer.h" + +#include "nnacl_c/infer/infer_register.h" + +int NLLLossInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C2NUM); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *logits = inputs[0]; + const TensorC *labels = inputs[1]; + const TensorC *weight = inputs[2]; + if (logits->shape_size_ != C2NUM || labels->shape_size_ != C1NUM || weight->shape_size_ != C1NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (labels->shape_[0] != logits->shape_[0] || weight->shape_[0] != logits->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorC *loss = outputs[0]; + TensorC *total_weight = outputs[1]; + + NLLLossParameter *param = (NLLLossParameter *)parameter; + if (param->reduction_type_ == Reduction_None) { + SetShapeTensor(loss, labels); + } else { + loss->shape_size_ = 0; + } + total_weight->shape_size_ = 0; + SetDataTypeFormat(loss, logits); + SetDataTypeFormat(total_weight, logits); + return NNACL_OK; +} + +REG_INFER(NLLLoss, PrimType_NLLLoss, NLLLossInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c9d011547c055a4e62e18ca3ca0ec438255f21dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_NLLLOSS_INFER_H +#define MINDSPORE_NNACL_NLLLOSS_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/nllloss_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NLLLossInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NLLLOSS_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5f6808e9297d98490dd756796a2c8957b38a1445 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/non_max_suppression_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int NonMaxSuppressionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeInt32; + output->format_ = input->format_; + return NNACL_INFER_INVALID; +} + +REG_INFER(NonMaxSuppression, PrimType_NonMaxSuppression, NonMaxSuppressionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b802a88ad494798f98c8aa2a1f30946ca9ed4c50 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H +#define MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NonMaxSuppressionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e61a8e0cf16488e02c7c4c41a11ac8c9f0b870bb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/one_hot_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int OneHotInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 4 && inputs_size != 3) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + const TensorC *depth_tensor = inputs[1]; + const TensorC *on_value = inputs[2]; + TensorC *output = outputs[0]; + const int *depth = (int *)(depth_tensor->data_); + if (depth == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(output, on_value); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + OneHotParameter *param = (OneHotParameter *)parameter; + int axis = param->axis_; + int input_rank = (int)(input->shape_size_); + if (axis < 0) { + axis += input_rank + 1; + } + if (input->shape_size_ >= MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + int res_insert = ShapeInsert(output->shape_, &output->shape_size_, axis, *depth); + if (res_insert == NNACL_ERR) { + return NNACL_ERR; + } + + return NNACL_OK; +} + +REG_INFER(OneHot, PrimType_OneHot, OneHotInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b5c0dddfd5c818f2e2457530c424dd020551559e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ONE_HOT_INFER_H +#define MINDSPORE_NNACL_ONE_HOT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/one_hot_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int OneHotInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ONE_HOT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a0609ebb642f97bdbbbb46c6fed76194435e428a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/pad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + PadParameter *param = (PadParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ > DEFAULT_PAD_NDIMS) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *paddings = inputs[1]; + int size = NNACLGetElementNum(paddings); + if (size > MAX_PAD_SIZE) { + return NNACL_PARAM_INVALID; + } + if (paddings->data_ == NULL) { + return NNACL_INFER_INVALID; + } + param->padding_length = size; + for (int i = 0; i < size; ++i) { + NNACL_CHECK_TRUE_RET(((int *)paddings->data_)[i] >= 0, NNACL_INFER_INVALID); + param->paddings_[i] = ((int *)paddings->data_)[i]; + } + + int output_shape[DEFAULT_PAD_NDIMS] = {0}; + size_t output_shape_size = 0; + for (size_t i = 0; i < input->shape_size_; i++) { + int shape = input->shape_[i] + param->paddings_[2 * i] + param->paddings_[2 * i + 1]; + ShapePush(output_shape, &output_shape_size, shape); + } + + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Pad, PrimType_PadFusion, PadInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b5d13882a7963317f3a50ef7ea8aa39a134d1a01 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_PAD_INFER_H +#define MINDSPORE_NNACL_PAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_PAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6929016c16e40ba580796b742d116c6702c6b0d3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/pooling_grad_infer.h" +#include <math.h> +#include "nnacl_c/infer/infer_register.h" + +int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + PoolingParameter *param = (PoolingParameter *)parameter; + int window_h = param->window_h_; + int window_w = param->window_w_; + if (param->global_) { + window_h = input_h; + window_w = input_w; + } + + if (param->stride_h_ == 0 || param->stride_w_ == 0) { + return NNACL_PARAM_INVALID; + } + if (param->pad_mode_ == Pad_same) { + NNACL_CHECK_ZERO_RETURN_ERR(param->stride_w_); + NNACL_CHECK_ZERO_RETURN_ERR(param->stride_h_); + int output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + int output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + int pad_h_all = ((output_h - 1) * param->stride_h_ + (window_h - 1) + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (window_w - 1) + 1 - input_w); + if (pad_h_all < 0) { + param->pad_u_ = param->pad_d_ = 0; + } else { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all < 0) { + param->pad_l_ = param->pad_r_ = 0; + } else { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } + SetDataTypeFormat(outputs[0], input); + SetShapeTensor(outputs[0], input); + return NNACL_OK; +} + +REG_INFER(AvgPoolGrad, PrimType_AvgPoolGrad, PoolingGradInferShape) +REG_INFER(MaxPoolGrad, PrimType_MaxPoolGrad, PoolingGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..80c13b25f8f27e163fba5d34f6ac6ee569508c27 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_POOLING_GRAD_INFER_H +#define MINDSPORE_NNACL_POOLING_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pooling_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POOLING_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1d970cbfac7e1e3f2b6b961817889ff0118bd03f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c @@ -0,0 +1,107 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/pooling_infer.h" +#include <math.h> +#include "nnacl_c/infer/infer_register.h" + +int ComputePadList(PoolingParameter *param, int input_h, int input_w, int output_h, int output_w) { + if (param == NULL) { + return NNACL_NULL_PTR; + } + int pad_h_all = ((output_h - 1) * param->stride_h_ + (param->window_h_ - 1) + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (param->window_w_ - 1) + 1 - input_w); + if (pad_h_all < 0) { + param->pad_u_ = param->pad_d_ = 0; + } else { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all < 0) { + param->pad_l_ = param->pad_r_ = 0; + } else { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + return NNACL_OK; +} + +int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + NNACL_CHECK_TRUE_RET(input->format_ == Format_NHWC, NNACL_FORMAT_ERROR); + for (size_t i = 0; i < outputs_size; i++) { + TensorC *output = outputs[i]; + SetDataTypeFormat(output, input); + } + PoolingParameter *param = (PoolingParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ < 3 || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + + int window_h = param->window_h_; + int window_w = param->window_w_; + if (param->global_) { + param->window_h_ = window_h = input_h; + param->window_w_ = window_w = input_w; + } + int output_h = 0; + int output_w = 0; + if ((param->stride_h_ == 0 || param->stride_w_ == 0) && !param->global_) { + return NNACL_PARAM_INVALID; + } + if (param->pad_mode_ == Pad_same) { + output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + 
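// SAME padding: the output covers ceil(input / stride) positions per axis; + // ComputePadList then derives total pad = (output - 1) * stride + window - input + // and splits it evenly, assigning any odd remainder to the bottom/right side. + 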
if (ComputePadList(param, input_h, input_w, output_h, output_w) != NNACL_OK) { + return NNACL_NULL_PTR; + } + } else { + int round_mode = (RoundType)param->round_type_; + if (round_mode == RoundType_Floor) { + output_h = floor((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1; + output_w = floor((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1; + } else if (round_mode == RoundType_Ceil) { + output_h = ceil((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1; + output_w = ceil((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1; + } else { + return NNACL_ERR; + } + } + int input_shape[MAX_SHAPE_SIZE]; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + input_shape[1] = output_h > 0 ? output_h : 1; + input_shape[2] = output_w > 0 ? output_w : 1; + for (size_t i = 0; i < outputs_size; i++) { + TensorC *output = outputs[i]; + SetShapeArray(output, input_shape, input_shape_size); + } + return NNACL_OK; +} + +REG_INFER(MaxPool, PrimType_MaxPoolFusion, PoolingInferShape) +REG_INFER(AvgPool, PrimType_AvgPoolFusion, PoolingInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c5587c6e0b6e511ca19231e1afa76ced4b343012 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POOLING_INFER_H +#define MINDSPORE_NNACL_POOLING_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pooling_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POOLING_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..00db9d12c6cccbe99fe4a20e65e01395015740e7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/power_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int PowerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x_tensor = inputs[0]; + TensorC *exp_tensor = NULL; + if (inputs_size == 2) { + exp_tensor = (TensorC *)inputs[1]; + PowParameter *param = (PowParameter *)parameter; + float *exp_data = (float *)(exp_tensor->data_); + if (exp_data == NULL) { + return NNACL_INFER_INVALID; + } + param->power_ = *exp_data; + } + TensorC *output_tensor = outputs[0]; + + SetDataTypeFormat(output_tensor, x_tensor); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (exp_tensor != NULL) { + bool exp_x_equal = ShapeEqual(exp_tensor->shape_, exp_tensor->shape_size_, x_tensor->shape_, x_tensor->shape_size_); + if (!exp_x_equal && NNACLGetElementNum(exp_tensor) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + + SetShapeTensor(output_tensor, x_tensor); + return NNACL_OK; +} + +REG_INFER(Pow, PrimType_PowFusion, PowerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8395060e8560b5f97b2507d85fcbd5f5baa7fd28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POWER_INFER_H +#define MINDSPORE_NNACL_POWER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pow_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PowerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POWER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a49b2c38e443a1c507892920dc215a742af69cb1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c @@ -0,0 +1,87 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/prior_box_infer.h" +#include <math.h> +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int PriorBoxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeFloat32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + float different_aspect_ratios[MAX_SHAPE_SIZE * 2 + 1]; // NOTE: flip doubles the number + different_aspect_ratios[0] = 1.0; + int32_t different_aspect_ratios_size = 1; + + PriorBoxParameter *param = (PriorBoxParameter *)parameter; + float *aspect_ratios = param->aspect_ratios; + if (aspect_ratios == NULL) { + return NNACL_NULL_PTR; + } + int32_t aspect_ratios_size = param->aspect_ratios_size; + NNACL_CHECK_TRUE_RET(aspect_ratios_size <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int32_t i = 0; i < aspect_ratios_size; i++) { + float ratio = aspect_ratios[i]; + if (fabsf(ratio) < EPSILON_VALUE) { + return NNACL_ERR; + } + + bool exist = false; + for (int32_t j = 0; j < different_aspect_ratios_size; j++) { + if (fabsf(ratio - different_aspect_ratios[j]) < EPSILON_VALUE) { + exist = true; + break; + } + } + if (!exist) { + different_aspect_ratios[different_aspect_ratios_size] = ratio; + different_aspect_ratios_size++; + if (param->flip) { + different_aspect_ratios[different_aspect_ratios_size] = 1.0f / ratio; + different_aspect_ratios_size++; + } + } + } + + int32_t min_sizes_size = param->min_sizes_size; + int32_t max_sizes_size = param->max_sizes_size; + int32_t num_priors_box = min_sizes_size * different_aspect_ratios_size + max_sizes_size; + const int kPriorBoxPoints = 4; + const int kPriorBoxN = 1; + const int kPriorBoxW = 1; + const int kPriorBoxC = 2; + + int32_t h = NNACLGetHeight(input) * NNACLGetWidth(input) * num_priors_box * kPriorBoxPoints; + output->shape_size_ = 4; + output->shape_[0] = kPriorBoxN; + output->shape_[1] = h; + output->shape_[2] = kPriorBoxW; + output->shape_[3] = kPriorBoxC; + return NNACL_OK; +} + +REG_INFER(PriorBox, PrimType_PriorBox, PriorBoxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a113415c030255e906d3dfe85c0e9d8cf0ba9c29 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_PRIOR_BOX_INFER_H +#define MINDSPORE_NNACL_PRIOR_BOX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/prior_box_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PriorBoxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_PRIOR_BOX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d0c58a0699e16d0e1a7b78f2b1d0c34105095c99 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/quant_dtype_cast_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int QuantDtypeCastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + QuantDtypeCastParameter *param = (QuantDtypeCastParameter *)parameter; + output->data_type_ = param->dstT_; + NNACL_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR); + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(QuantDTypeCast, PrimType_QuantDTypeCast, QuantDtypeCastInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fba14604cf60b6e09bd1d1e04855b41e43f3c1e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H +#define MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct QuantDtypeCastParameter { + OpParameter op_parameter_; + int srcT_; // deprecated + int dstT_; +} QuantDtypeCastParameter; + +int QuantDtypeCastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3b1c7aec6a5e2c8aabf67b9f815d4aa52b5b5e7c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/ragged_range_infer.h" +#include <math.h> +#include "nnacl_c/infer/infer_register.h" + +int CheckInputTensor(const TensorC *const *inputs) { + if (inputs[0]->data_ == NULL || inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (inputs[0]->shape_size_ != 0 && inputs[0]->shape_size_ != 1) { + return NNACL_ERR; + } + return NNACL_OK; +} + +int GetRows(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar, bool deltas_is_scalar, + int *rows) { + NNACL_CHECK_NULL_RETURN_ERR(rows); + int sizes[3]; + int not_scalar_count = 0; + if (!starts_is_scalar) { + sizes[not_scalar_count++] = inputs[0]->shape_[0]; + } + if (!limits_is_scalar) { + sizes[not_scalar_count++] = inputs[1]->shape_[0]; + } + if (!deltas_is_scalar) { + sizes[not_scalar_count++] = inputs[2]->shape_[0]; + } + for (int i = 1; i < not_scalar_count; i++) { + if (sizes[i] != sizes[i - 1]) { + return NNACL_ERR; + } + } + *rows = not_scalar_count == 0 ? 1 : sizes[0]; + return NNACL_OK; +} + +int GetOutputValueElementNum(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar, + bool deltas_is_scalar, int rows, int *output_value_element_num) { + int count = 0; + switch (inputs[0]->data_type_) { + case kNumberTypeInt32: { + int *starts = (int *)(inputs[0]->data_); + int *limits = (int *)(inputs[1]->data_); + int *deltas = (int *)(inputs[2]->data_); + for (int i = 0; i < rows; i++) { + int start = starts_is_scalar ? starts[0] : starts[i]; + int limit = limits_is_scalar ? 
limits[0] : limits[i]; + int delta = deltas_is_scalar ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN_ERR(delta); + count += MSMAX((int)(ceil((float)(limit - start) / delta)), 0); + } + } break; + case kNumberTypeFloat32: { + float *starts = (float *)(inputs[0]->data_); + float *limits = (float *)(inputs[1]->data_); + float *deltas = (float *)(inputs[2]->data_); + for (int i = 0; i < rows; i++) { + float start = starts_is_scalar ? starts[0] : starts[i]; + float limit = limits_is_scalar ? limits[0] : limits[i]; + float delta = deltas_is_scalar ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN_ERR(delta); + count += MSMAX((ceil((limit - start) / delta)), 0); + } + } break; + default: { + return NNACL_ERR; + } + } + *output_value_element_num = count; + return NNACL_OK; +} + +int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = kNumberTypeInt32; + outputs[0]->format_ = inputs[0]->format_; + SetDataTypeFormat(outputs[1], inputs[0]); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int ret = CheckInputTensor(inputs); + if (ret != NNACL_OK) { + return ret; + } + + bool starts_is_scalar = inputs[0]->shape_size_ == 0; + bool limits_is_scalar = inputs[1]->shape_size_ == 0; + bool deltas_is_scalar = inputs[2]->shape_size_ == 0; + int rows; + ret = GetRows(inputs, starts_is_scalar, limits_is_scalar, deltas_is_scalar, &rows); + if (ret != NNACL_OK) { + return ret; + } + int output_value_element_num; + ret = GetOutputValueElementNum(inputs, starts_is_scalar, limits_is_scalar, deltas_is_scalar, rows, + &output_value_element_num); + if (ret != NNACL_OK) { + return ret; + } + outputs[0]->shape_size_ = 1; + outputs[0]->shape_[0] = rows + 1; + outputs[1]->shape_size_ = 1; + outputs[1]->shape_[0] = output_value_element_num; + return NNACL_OK; +} + +REG_INFER(RaggedRange, PrimType_RaggedRange, RaggedRangeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..22613326809c25e0ea62a548c98772e360d4f759 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_RAGGED_RANGE_INFER_H +#define MINDSPORE_NNACL_RAGGED_RANGE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/ragged_range_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RAGGED_RANGE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..20d18626cc75cf40fb218e114f84e8cb51231bcd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/random_normal_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int RandomNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = inputs[0]->data_type_; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(outputs[0], inputs[0]); + + return NNACL_OK; +} + +REG_INFER(RandomNormal, PrimType_RandomNormal, RandomNormalInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5dce4607e356b96e241c40e3c084525295b38c9a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H +#define MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RandomNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5f214095e0fd59e34df9a492b8fc09dd19a7b618 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/random_standard_normal_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = kNumberTypeFloat32; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int32_t *input_data = (int32_t *)(inputs[0]->data_); + if (input_data == NULL) { + return NNACL_INFER_INVALID; + } + int input_num = NNACLGetElementNum(inputs[0]); + if (input_num > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < input_num; i++) { + ShapePush(output_shape, &output_shape_size, input_data[i]); + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + + return NNACL_OK; +} + +REG_INFER(RandomStandardNormal, PrimType_RandomStandardNormal, RandomStandardNormalInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6a31082a2714e0ae96dd850e2609e9393923a8ac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H +#define MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..619c658f9be502af99049b5de80bae25ab51fa8c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/range_infer.h" +#include <math.h> +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/range_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, C3NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = inputs_size == C3NUM ? 
input->data_type_ : kNumberTypeInt32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[FIRST_INPUT]) < 1) { + return NNACL_ERR; + } + int shape_size = 0; + if (inputs_size == C3NUM) { + NNACL_CHECK_FALSE(inputs[FIRST_INPUT]->data_ == NULL, NNACL_INFER_INVALID); + NNACL_CHECK_FALSE(inputs[SECOND_INPUT]->data_ == NULL, NNACL_INFER_INVALID); + NNACL_CHECK_FALSE(inputs[THIRD_INPUT]->data_ == NULL, NNACL_INFER_INVALID); + if ((inputs[FIRST_INPUT]->data_type_ != inputs[SECOND_INPUT]->data_type_) || + (inputs[FIRST_INPUT]->data_type_ != inputs[THIRD_INPUT]->data_type_)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[SECOND_INPUT]) < 1 || NNACLGetElementNum(inputs[THIRD_INPUT]) < 1) { + return NNACL_ERR; + } + switch (inputs[0]->data_type_) { + case kNumberTypeInt: + case kNumberTypeInt32: { + int start = *(int *)(inputs[0]->data_); + int limit = *(int *)(inputs[1]->data_); + int delta = *(int *)(inputs[2]->data_); + if (delta == 0) { + return NNACL_ERR; + } + shape_size = imax((int)(ceil((float)(limit - start) / delta)), 0); + } break; + case kNumberTypeFloat32: + case kNumberTypeFloat: { + float start = *(float *)(inputs[0]->data_); + float limit = *(float *)(inputs[1]->data_); + float delta = *(float *)(inputs[2]->data_); + if (fabsf(delta) < EPSILON_VALUE) { + return NNACL_ERR; + } + shape_size = imax((int)(ceil((float)(limit - start) / delta)), 0); + } break; + default: { + return NNACL_ERR; + } + } + } else { + RangeParameter *param = (RangeParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->delta_ == 0) { + return NNACL_PARAM_INVALID; + } + shape_size = ceil((float)(param->limit_ - param->start_) / param->delta_); + } + + output->shape_size_ = 1; + output->shape_[0] = shape_size; + return NNACL_OK; +} + +REG_INFER(Range, PrimType_Range, RangeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..eb1401a1c9661fb58cfbc06bb9d98aff86882653 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_RANGE_INFER_H +#define MINDSPORE_NNACL_RANGE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANGE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2c6d9299f2eeec7408716c6378daf989d81c01b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/rank_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int RankInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + output->shape_size_ = 1; + output->shape_[0] = 1; + return NNACL_OK; +} + +REG_INFER(Rank, PrimType_Rank, RankInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5f7d2c461780a0a9f9e3c3af2be38c91a2d6c895 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_RANK_INFER_H +#define MINDSPORE_NNACL_RANK_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RankInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..00e68b34a433b3da2649e1e6303f455c637b0135 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c @@ -0,0 +1,95 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/reduce_concat_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/split_parameter.h" + +int DataTypeJudge2(const TensorC *input, const TensorC *output) { + if ((input->data_type_ != output->data_type_) && + !((input->data_type_ == kNumberTypeFloat16 && output->data_type_ == kNumberTypeFloat32) || + (input->data_type_ == kNumberTypeFloat32 && output->data_type_ == kNumberTypeFloat16))) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +int ReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *input0_shape = inputs[0]->shape_; + size_t input0_shape_size = inputs[0]->shape_size_; + + int axis = C2NUM; + if (axis < 0 || axis >= (int)input0_shape_size) { + return NNACL_ERR; + } + if (input0_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; + size_t input0_shape_without_axis_size = 0; + ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_axis_dim = input0_shape[axis]; + for (size_t i = 1; i < inputs_size; ++i) { + size_t input_i_shape_size = inputs[i]->shape_size_; + if (input_i_shape_size != input0_shape_size) { + return NNACL_PARAM_INVALID; + } + int shape_tmp[MAX_SHAPE_SIZE] = {0}; + size_t shape_tmp_size = 0; + ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); + int data_type_judge = DataTypeJudge2(inputs[i], output); + if (data_type_judge != NNACL_OK) { + 
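// DataTypeJudge2 accepts matching dtypes and tolerates only fp16 <-> fp32 mixing; + // any other input/output dtype combination aborts the fused shape inference here. + 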
return data_type_judge; + } + int axis_tmp = shape_tmp[axis]; + erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + + output_axis_dim += axis_tmp; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input0_shape_size; + for (size_t i = 0; i < input0_shape_size; i++) { + output_shape[i] = input0_shape[i]; + } + output_shape[axis] = output_axis_dim; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(ReduceConcatFusion, PrimType_Inner_ReduceConcatFusion, ReduceConcatFusionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..267855ecb989e326361bf3bb71742378d48aed55 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H +#define MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..010d373988123f366d09c4f586349e434a2eac3d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/infer/reduce_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ReduceOnAllAxes(const TensorC *input, TensorC *output, int *out_shape, size_t out_shape_size, bool keep_dims) { + if (keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + ShapePush(out_shape, &out_shape_size, 1); + } + } + SetShapeArray(output, out_shape, out_shape_size); + output->data_type_ = input->data_type_; + return NNACL_OK; +} + +int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, const int *actual_axes, TensorC *output, int *out_shape, + size_t out_shape_size, bool keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + bool reduce_axis = false; + for (size_t idx = 0; idx < num_axes; ++idx) { + if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx]) + input->shape_size_ == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + ShapePush(out_shape, &out_shape_size, 1); + } + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +bool IsReduceAllAxes(const TensorC *const *inputs, size_t inputs_size) { + if (inputs_size == 1) { + return true; + } + // When axes not given, reduce op will have two input tensor by the old version converter_lite tool. + if (inputs_size == 2 && inputs[1]->shape_size_ == 1 && inputs[1]->shape_[0] == 0) { + return true; + } + return false; +} + +int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ReduceParameter *param = (ReduceParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + bool keep_dims = param->keep_dims_; + int out_shape[MAX_SHAPE_SIZE] = {0}; + const size_t out_shape_size = 0; + if (IsReduceAllAxes(inputs, inputs_size)) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } + + // get axes from input tensor + const TensorC *axes_input = inputs[1]; + NNACL_CHECK_NULL_RETURN_ERR(axes_input->data_); + + int num_axes; + if (axes_input->shape_size_ == 1) { + num_axes = axes_input->shape_[0]; + } else if (axes_input->shape_size_ == 0) { + num_axes = 1; + } else { + return NNACL_ERR; + } + if (num_axes > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int rank = (int)(input->shape_size_); + if (rank > MAX_SHAPE_SIZE || rank < 0) { + return NNACL_ERR; + } + int actual_axes[MAX_SHAPE_SIZE] = {0}; + size_t actual_axes_size = 0; + int ret = GetInt32DataFromTensor(axes_input, actual_axes, &actual_axes_size); + if (ret != NNACL_OK) { + return ret; + } + + if (param->reduce_to_end_) { + if (num_axes != 1) { + return NNACL_ERR; + } + + if (actual_axes[0] < -1 * rank || actual_axes[0] >= rank) { + return NNACL_PARAM_INVALID; + } + int begin_axis; + begin_axis = actual_axes[0] < 0 ? 
actual_axes[0] + rank : actual_axes[0]; + for (int i = begin_axis + 1; i < rank; ++i) { + ShapePush(actual_axes, &actual_axes_size, i); + } + num_axes = rank - begin_axis; + keep_dims = false; + } + // reduce on all axes + if (num_axes == 0) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } + // reduce on selected axes + return ReduceOnSelectedAxes(input, (size_t)num_axes, actual_axes, output, out_shape, out_shape_size, keep_dims); +} + +REG_INFER(Reduce, PrimType_ReduceFusion, ReduceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c60eea7c1c0e19d558e220143aae5e3f93c5883e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_INFER_H +#define MINDSPORE_NNACL_REDUCE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6c9c0a3e2a4dfea88136047aecbdcab8526f6759 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/reduce_scatter_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ReduceScatterParameter *param = (ReduceScatterParameter *)parameter; + if (param->rank_size_ <= 0) { + return NNACL_INFER_INVALID; + } + + const TensorC *input_tensor = inputs[0]; + const int *in_shape = input_tensor->shape_; + TensorC *out_tensor = outputs[0]; + + if (in_shape[0] % param->rank_size_ != 0) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + out_shape[0] = in_shape[0] / param->rank_size_; + out_shape_size++; + for (int i = 1; i < input_tensor->shape_size_; i++) { + out_shape[i] = in_shape[i]; + out_shape_size++; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + + return NNACL_OK; +} + +REG_INFER(ReduceScatter, PrimType_ReduceScatter, ReduceScatterInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..80246f1c9b737e7f9af8b9e9cf88563d3d43d2d9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H +#define MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/reduce_scatter_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..8c28b9dd4717b021a0adcf79c00fae6676d596d5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c @@ -0,0 +1,221 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/reshape_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size_t *out_shape_size, int shape_size) { + int input_count = NNACLGetElementNum(inputs[0]); + int index = 0; + int size = 1; + for (int i = 0; i < shape_size; i++) { + if ((int)(data[i]) == -1) { + index = i; + } else if ((int)(data[i]) == 0) { + size *= inputs[0]->shape_[i]; + } else { + size *= data[i]; + } + ShapePush(out_shape, out_shape_size, data[i]); + } + if (size == 0) { + return NNACL_ERR; + } + if ((int)(data[index]) == -1) { + if (index >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + out_shape[index] = input_count / size; + } + return NNACL_OK; +} + +int CalNewShape(const TensorC *in_tensor, int *out_shape, size_t out_shape_size) { + int in_shape_size = 1; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + in_shape_size *= in_tensor->shape_[i]; + } + int64_t infer_index = -1; + int out_shape_size_new = 1; + for (size_t i = 0; i < out_shape_size; i++) { + if (out_shape[i] == -1) { + if (infer_index == -1) { + infer_index = (int64_t)(i); + } else { + return NNACL_ERR; + } + } else if (out_shape[i] < 0) { + return NNACL_ERR; + } else if (out_shape[i] == 0) { + if (NNACLGetElementNum(in_tensor) != 0) { + out_shape[i] = in_tensor->shape_[i]; + out_shape_size_new *= out_shape[i]; + } else { + out_shape_size_new = 0; + break; + } + } else { + out_shape_size_new *= out_shape[i]; + } + } + if (infer_index == -1 && out_shape_size_new != in_shape_size) { + return NNACL_ERR; + } + if (infer_index != -1) { + if (out_shape_size_new == 0) { + return NNACL_ERR; + } + if (infer_index >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + out_shape[infer_index] = in_shape_size / out_shape_size_new; + } + return NNACL_OK; +} + +int CalShapeByType(const TensorC *const *inputs, size_t shape_size, int *out_shape, size_t *out_shape_size) { + const TensorC *shape_tensor = inputs[1]; + if (shape_size == 0) { + return NNACL_ERR; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW((sizeof(int)), shape_size), NNACL_ERR); + int *data_int = (int *)malloc(sizeof(int) * shape_size); + if (data_int == NULL) { + return NNACL_ERR; + } + switch (shape_tensor->data_type_) { + case kNumberTypeInt8: { + int8_t *data = (int8_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeInt32: { + int32_t *data = (int32_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeInt64: { + int64_t *data = (int64_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, 
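/* the int64 dims were narrowed to int in the copy loop above */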
inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeFloat: { + float *data = (float *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeUInt32: { + uint32_t *data = (uint32_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = (int)data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + default: { + free(data_int); + return NNACL_ERR; + } + } + free(data_int); + return NNACL_OK; +} + +int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + ReshapeParameter *param = (ReshapeParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + if (inputs_size == 2) { + const TensorC *shape_tensor = inputs[1]; + if (NNACLGetElementNum(input) == 1) { + if (shape_tensor->data_ == NULL || (shape_tensor->shape_size_ == 1 && shape_tensor->shape_[0] == 0)) { + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; + } + } + + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int shape_size = NNACLGetElementNum(shape_tensor); + if (shape_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int calRet = CalShapeByType(inputs, shape_size, out_shape, &out_shape_size); + if (calRet != NNACL_OK) { + return calRet; + } + } else if (inputs_size == 1) { + if (param->shape_dim_ > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (int i = 0; i < param->shape_dim_; ++i) { + ShapePush(out_shape, &out_shape_size, param->shape_[i]); + } + } else { + return NNACL_ERR; + } + int ret = CalNewShape(inputs[0], out_shape, out_shape_size); + if (ret != NNACL_OK) { + return ret; + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Reshape, PrimType_Reshape, ReshapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1f79b52f67da9f97a47e84bac575d1fd9837736e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESHAPE_INFER_H +#define MINDSPORE_NNACL_RESHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/reshape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0867769e16c6f72f7375e8ae4c9350d3452a4fa2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/resize_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input_1 = inputs[1]; + if (input_1->shape_size_ == 4) { + ShapeSet(output->shape_, &output->shape_size_, input_1->shape_, input_1->shape_size_); + } else if (input_1->shape_size_ == 1 && input_1->shape_[0] == 2 && input_1->data_type_ == kNumberTypeInt32) { + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + int32_t *data = (int32_t *)(input_1->data_); + + ShapePush(output_shape, &output_shape_size, NNACLGetBatch(input)); + ShapePush(output_shape, &output_shape_size, data[0]); + ShapePush(output_shape, &output_shape_size, data[1]); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +REG_INFER(ResizeGrad, PrimType_ResizeGrad, ResizeGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..87f2786a28831f6c1a2736c75f0d92eb153d2f16 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ +#define MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/resize_grad.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..daf5e9de32e3885b865952bd1807d8e4a3708b85 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/resize_infer.h" +#include +#include +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" + +int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) { + const TensorC *input = inputs[0]; + const TensorC *shape_tensor = inputs[1]; + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int shape_size = NNACLGetElementNum(shape_tensor); + void *origin_data = shape_tensor->data_; + if (origin_data == NULL) { + return NNACL_INFER_INVALID; + } + switch (shape_size) { + case 2: + case 4: { + int height_index = 0; + int width_index = 1; + if (shape_size == 4) { + height_index = kNHWC_H; + width_index = kNHWC_W; + } + if (shape_tensor->data_type_ == kNumberTypeInt32) { + int32_t *data = (int32_t *)(origin_data); + param->new_height_ = data[height_index]; + param->new_width_ = data[width_index]; + } else if (shape_tensor->data_type_ == kNumberTypeFloat32) { + float *data = (float *)(origin_data); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[height_index]), NNACLGetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[width_index]), NNACLGetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW); + param->new_height_ = round(data[height_index] * NNACLGetHeight(input)); + param->new_width_ = round(data[width_index] * NNACLGetWidth(input)); + } else if (shape_tensor->data_type_ == kNumberTypeFloat16) { + uint16_t *data = (uint16_t *)(shape_tensor->data_); + float scale_height = ShortToFloat32(data[height_index]); + float scale_width = ShortToFloat32(data[width_index]); + param->new_height_ = round(scale_height * NNACLGetHeight(input)); + param->new_width_ = round(scale_width * NNACLGetWidth(input)); + } + break; + } + case 1: { + // caffe zoom_factor + int scale; + if (shape_tensor->data_type_ == kNumberTypeInt32) { + int *data = (int *)(origin_data); + scale = data[0]; + } else { + return NNACL_ERR; + } + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetHeight(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetWidth(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW); + param->new_height_ = NNACLGetHeight(input) + (NNACLGetHeight(input) - 1) * (scale - 1); + param->new_width_ = NNACLGetWidth(input) + (NNACLGetWidth(input) - 1) * (scale - 1); + break; + } + default: { + return NNACL_ERR; + } + } + return NNACL_OK; +} + +int CalculateNewHeightAndWidth(const TensorC *const *inputs, size_t inputs_size, ResizeParameter *param) { + if (inputs_size == 2) { + return HandleTwoInputs(inputs, param); + } else if (inputs_size == 1) { + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 0 && input->shape_size_ != 4) { + return NNACL_ERR; + } + ResizeParameter *param = (ResizeParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapePush(output_shape, &output_shape_size, 
NNACLGetBatch(input)); + int ret = CalculateNewHeightAndWidth(inputs, inputs_size, param); + if (ret == NNACL_OK) { + ShapePush(output_shape, &output_shape_size, param->new_height_); + ShapePush(output_shape, &output_shape_size, param->new_width_); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + } + return ret; +} + +REG_INFER(Resize, PrimType_Resize, ResizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c9549b0527a925d5e294f2d12bcfce39c452a1db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESIZE_INFER_H +#define MINDSPORE_NNACL_RESIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/resize_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a95fa84ff6d3a7e82ef87b0438ffd421869a4587 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/rfft_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeComplex64; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ >= MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + RfftParameter *param = (RfftParameter *)parameter; + if (input->shape_size_ < 1) { + return NNACL_ERR; + } + output->shape_[input->shape_size_ - 1] = param->fft_length_ / 2 + 1; + ShapePush(output->shape_, &(output->shape_size_), 2); + return NNACL_OK; +} + +REG_INFER(Rfft, PrimType_Rfft, RfftInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c863ede9932c424db510354399bc2d87bcf3e7c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RFFT_INFER_H +#define MINDSPORE_NNACL_RFFT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct RfftParameter { + OpParameter op_parameter_; + int fft_length_; +} RfftParameter; + +int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RFFT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6ae35800c8f35832e25517ba44fe4c14b26fe70f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/roi_pooling_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + const TensorC *roi = inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ROIPoolingParameter *param = (ROIPoolingParameter *)parameter; + output->shape_size_ = 4; + output->shape_[0] = roi->shape_[0]; + output->shape_[1] = param->pooledH_; + output->shape_[2] = param->pooledW_; + output->shape_[3] = NNACLGetChannel(input); + return NNACL_OK; +} + +REG_INFER(ROIPooling, PrimType_ROIPooling, ROIPoolingInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4410ced9d7a9810ff55a585beddd8cafd9f31acd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ROI_POOLING_INFER_H +#define MINDSPORE_NNACL_ROI_POOLING_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ROI_POOLING_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..536244b66aff79146936ae27353219f9bbe2ba9e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/scatter_nd_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *shape = inputs[THIRD_INPUT]; + if (shape->data_ == NULL) { + return NNACL_INFER_INVALID; + } + const TensorC *update = inputs[SECOND_INPUT]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, update); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int *shape_data = (int *)(shape->data_); + NNACL_CHECK_TRUE_RET(NNACLGetElementNum(shape) <= MAX_SHAPE_SIZE, NNACL_ERR); + SetShapeArray(output, shape_data, (size_t)NNACLGetElementNum(shape)); + return NNACL_OK; +} + +REG_INFER(ScatterNd, PrimType_ScatterNd, ScatterNdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..154c356c11b635fe2391a832a3cb2ee50af1ffe3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5ae93168b33223ef7f0ed28098dd857266bf0f15 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/scatter_nd_update_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_x = inputs[0]; + const TensorC *indices = inputs[1]; + const TensorC *updates = inputs[2]; + TensorC *output = outputs[0]; + if (updates->data_type_ != input_x->data_type_ || + (indices->data_type_ != kNumberTypeInt32 && indices->data_type_ != kNumberTypeInt64)) { + return NNACL_ERR; + } + SetDataTypeFormat(output, input_x); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (indices->shape_size_ < 2 || indices->shape_[indices->shape_size_ - 1] > input_x->shape_size_) { + return NNACL_ERR; + } + if (updates->shape_size_ != + (indices->shape_size_ - 1) + input_x->shape_size_ - indices->shape_[indices->shape_size_ - 1]) { + return NNACL_ERR; + } + for (int i = 0; i < updates->shape_size_; i++) { + if ((i < indices->shape_size_ - 1 && updates->shape_[i] != indices->shape_[i]) || + (i >= indices->shape_size_ - 1 && + updates->shape_[i] != + input_x->shape_[indices->shape_[indices->shape_size_ - 1] + i - indices->shape_size_ + 1])) { + return NNACL_ERR; + } + } + SetShapeArray(output, input_x->shape_, input_x->shape_size_); + return NNACL_OK; +} + +REG_INFER(ScatterNdUpdate, PrimType_ScatterNdUpdate, ScatterNdUpdateInferShape) +REG_INFER(TensorScatterAdd, PrimType_TensorScatterAdd, ScatterNdUpdateInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..37b827480c71c8b65fae7904bceca7364e26299d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1b589d8fb37db963f225f95e369ca91386928e33 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/select_infer.h" +#include <string.h> +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" + +int SelectInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2 * outputs_size + 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + for (size_t i = 0; i < outputs_size; i++) { + const TensorC *input = inputs[i + 1]; + TensorC *output = outputs[i]; + SetDataTypeFormat(output, input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < outputs_size; i++) { + const TensorC *input = inputs[i + 1]; + TensorC *output = outputs[i]; + if (input->data_type_ == kObjectTypeTensorType) { + TensorListC *input_tensorlist = (TensorListC *)(input); + TensorListC *output_tensorlist = (TensorListC *)(output); + output_tensorlist->element_shape_size_ = input_tensorlist->element_shape_size_; + for (size_t j = 0; j < input_tensorlist->element_shape_size_; j++) { + output_tensorlist->element_shape_[j] = input_tensorlist->element_shape_[j]; + } + output_tensorlist->max_elements_num_ = input_tensorlist->max_elements_num_; + output_tensorlist->tensors_data_type_ = input_tensorlist->tensors_data_type_; + output_tensorlist->element_num_ = input_tensorlist->element_num_; + + for (size_t j = 0; j < output_tensorlist->element_num_; j++) { + memcpy(&output_tensorlist->tensors_[j], &input_tensorlist->tensors_[j], sizeof(TensorC)); + } + } else { + SetShapeTensor(output, input); + } + } + return NNACL_OK; +} + +REG_INFER(Select, PrimType_Select, SelectInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8575b19e59c0321a6328a94c821310d75ec5798e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies 
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SELECT_INFER_H +#define MINDSPORE_NNACL_SELECT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SelectInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SELECT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b198cdf5f8851e2cd23b9abb7de7a6ffd7fd6535 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sgd_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SgdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 6); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[3]) || NNACLGetElementNum(inputs[2]) != 1 || + NNACLGetElementNum(inputs[4]) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(SGD, PrimType_SGD, SgdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8246a6a287525b4bc8686e34a87c82c5912732b5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SGD_INFER_H +#define MINDSPORE_NNACL_SGD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SgdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SGD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..8d4dfba0da8b4699431a56e090076db9038eb90e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c @@ -0,0 +1,97 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/shape_fusion_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CalculateOutput(const TensorC *in_tensor, const TensorC *matrix_tensor, TensorC *out_tensor, size_t input_len, + size_t origin_out_size) { + size_t out_size = out_tensor->shape_size_ == 0 ? 
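/* a rank-0 (scalar) output still holds one element */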
1 : (size_t)(out_tensor->shape_[0]); + if (out_size != origin_out_size && out_tensor->data_ != NULL) { + free(out_tensor->data_); + out_tensor->data_ = NULL; + } + size_t matrix_data_size = input_len * out_size * sizeof(float); + float *matrix_data = (float *)(malloc(matrix_data_size)); + NNACL_CHECK_NULL_RETURN_ERR(matrix_data); + if (matrix_tensor->data_type_ == kNumberTypeFloat32 || matrix_tensor->data_type_ == kNumberTypeFloat) { + memcpy(matrix_data, matrix_tensor->data_, matrix_data_size); +#ifdef ENABLE_FP16 + } else if (matrix_tensor->data_type_ == kNumberTypeFloat16) { + for (size_t i = 0; i < input_len * out_size; i++) { + matrix_data[i] = (float)(((float16_t *)(matrix_tensor->data_))[i]); + } +#endif + } else { + free(matrix_data); + return NNACL_ERR; + } + if (out_tensor->data_ == NULL) { + out_tensor->data_ = malloc(out_size * sizeof(int)); + } + int *data = (int *)out_tensor->data_; + if (data == NULL) { + free(matrix_data); + return NNACL_ERR; + } + memset(data, 0, out_size * sizeof(int)); + for (size_t i = 0; i < out_size; i++) { + for (size_t j = 0; j < input_len - 1; j++) { + data[i] += (int)(in_tensor->shape_[j] * matrix_data[i * input_len + j]); + } + data[i] += (int)(matrix_data[i * input_len + input_len - 1]); + } + free(matrix_data); + return NNACL_OK; +} + +int ShapeFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size == outputs_size + 1, NNACL_INPUT_TENSOR_ERROR); + const TensorC *in_tensor = inputs[0]; + size_t input_len = in_tensor->shape_size_ + 1; + for (size_t out_idx = 0; out_idx < outputs_size; out_idx++) { + TensorC *out_tensor = outputs[out_idx]; + size_t origin_out_size = + out_tensor->data_ == NULL ? 0 : (out_tensor->shape_size_ == 0 ? 1 : (size_t)out_tensor->shape_[0]); + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // calculate output tensor shape. + const TensorC *matrix_tensor = inputs[out_idx + 1]; + if (matrix_tensor->shape_size_ == 1) { + out_tensor->shape_size_ = 0; + out_tensor->shape_[0] = 0; + } else { + out_tensor->shape_size_ = 1; + out_tensor->shape_[0] = (int)(matrix_tensor->shape_[0]); + } + int ret = CalculateOutput(in_tensor, matrix_tensor, out_tensor, input_len, origin_out_size); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +REG_INFER(ShapeFusion, PrimType_Inner_ShapeFusion, ShapeFusionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3c014100a16e02f29b2590d7d13e33ee454c0bf6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ +#define MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ShapeFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..8a2e3ff27566aa2bfd2c75f7d7490b7189999668 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/shape_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + out_tensor->shape_size_ = 1; + out_tensor->shape_[0] = (int)(in_tensor->shape_size_); + return NNACL_OK; +} + +REG_INFER(Shape, PrimType_Shape, ShapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..27721d0b822e591c61f8c35dbb6ca393f3808d88 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SHAPE_INFER_H +#define MINDSPORE_NNACL_SHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..8b9ec9bd665eaa914533483530a1038fc9e9adee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/size_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + out_tensor->shape_size_ = 0; + out_tensor->shape_[0] = 1; + + return NNACL_OK; +} + +REG_INFER(SizeOp, PrimType_Size, SizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f1ccf7cfd067c51f82cc546601961438f11186ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SIZE_INFER_H +#define MINDSPORE_NNACL_SIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c041b3ea8ee4c3e6a5baee701a07dc323988235a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/slice_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size) { + // the data type of slice's begin and size inputs must be int32 + if (inputs_size >= 2) { + if (inputs[1]->data_type_ != kNumberTypeInt32) { + return false; + } + } + if (inputs_size == 3) { + if (inputs[2]->data_type_ != kNumberTypeInt32) { + return false; + } + } + return true; +} + +int InitBeginAndSizeParam(const TensorC *const *inputs, int *begin, int *size, int param_length) { + /* init begin parameter */ + int slice_begin_size = NNACLGetElementNum(inputs[1]); + int *begin_ptr = (int *)(inputs[1]->data_); + if (slice_begin_size != param_length || begin_ptr == NULL) { + return NNACL_INFER_INVALID; + } + if (slice_begin_size > MAX_AXIS_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < slice_begin_size; i++) { + begin[i] = begin_ptr[i]; + } + + /* init size parameter */ + int slice_size_size = NNACLGetElementNum(inputs[2]); + int *size_ptr = (int *)(inputs[2]->data_); + if (slice_size_size != param_length || size_ptr == NULL) { + return NNACL_INFER_INVALID; + } + if (slice_size_size > MAX_AXIS_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < slice_size_size; i++) { + size[i] = size_ptr[i]; + } + return NNACL_OK; +} + +int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + + if (!CheckInputsDataType(inputs, inputs_size)) { + return NNACL_ERR; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + SliceParameter *param = (SliceParameter *)parameter; + int param_length = (int)(input->shape_size_); + output->shape_size_ = input->shape_size_; + int begin[MAX_SHAPE_SIZE]; + int size[MAX_SHAPE_SIZE]; + + 
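// e.g. with input shape [4, 6], begin = {0, 2} and size = {4, -1}, the slice keeps the last four columns: any -1 in size is resolved to input->shape_[i] - begin[i] below. +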
ret = InitBeginAndSizeParam(inputs, begin, size, param_length); + if (ret != NNACL_OK) { + return ret; + } + + for (int32_t i = 0; i < param_length; ++i) { + if (param->axis_[i] < 0) { + NNACL_CHECK_INT_ADD_NOT_OVERFLOW(param->axis_[i], (int)input->shape_size_, NNACL_PARAM_INVALID); + param->axis_[i] += (int)input->shape_size_; + } + NNACL_CHECK_TRUE_RET(param->axis_[i] >= 0 && param->axis_[i] < param_length, NNACL_PARAM_INVALID); + begin[param->axis_[i]] = begin[i]; + size[param->axis_[i]] = size[i]; + } + + for (int32_t i = 0; i < param_length; ++i) { + if (size[i] < 0 && size[i] != -1) { + return NNACL_PARAM_INVALID; + } + if (begin[i] < 0) { + return NNACL_PARAM_INVALID; + } + if (input->shape_[i] < begin[i]) { + return NNACL_PARAM_INVALID; + } + if (size[i] > (input->shape_[i] - begin[i])) { + return NNACL_PARAM_INVALID; + } + + output->shape_[i] = size[i] < 0 ? input->shape_[i] - begin[i] : size[i]; + } + return NNACL_OK; +} + +REG_INFER(Slice, PrimType_SliceFusion, SliceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..cdd0a09b5c56ae097f4f81b26acdc0b60bfc1797 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SLICE_INFER_H +#define MINDSPORE_NNACL_SLICE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SLICE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..eeb4ce5903375aa863ea7be49ed2792e66fbf9d4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/softmax_cross_entropy_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + out->shape_size_ = 2; + out->shape_[0] = in0->shape_[0]; + out->shape_[1] = 1; + SetDataTypeFormat(out, in0); + + if (1 < outputs_size) { + TensorC *grads = outputs[1]; + SetShapeTensor(grads, in0); + SetDataTypeFormat(grads, in0); + } + return NNACL_OK; +} + +REG_INFER(SoftmaxCrossEntropyWithLogits, PrimType_SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ac407fb5193db2fda4f893c18cb9176368c7c19d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H +#define MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..12dd4f68830e7b799e311f97c150bda7b8b8b5ee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/softmax_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + // there is a model with an 8-dim input, which runs on ascend910. + if (input->shape_size_ > DIMENSION_8D) { + return NNACL_ERR; + } + + SoftmaxParameter *param = (SoftmaxParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ > (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Softmax, PrimType_Softmax, SoftMaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..556f46d46abd3b91b0d1f9645ceb38ce3b35e2da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SOFTMAX_INFER_H +#define MINDSPORE_NNACL_SOFTMAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SOFTMAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..9e5d6f83af2ce033118bb46dfd376836fad8b272 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/space_to_batch_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + SpaceToBatchParameter *param = (SpaceToBatchParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + + int *block_shape = param->block_sizes_; + int block_shape_size = param->m_; + int *paddings = param->paddings_; + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (block_shape_size == 2) { + padding_left = paddings[2]; + padding_right = paddings[3]; + block_w = block_shape[1]; + } + + NNACL_CHECK_ZERO_RETURN_ERR(block_shape[0]); + NNACL_CHECK_ZERO_RETURN_ERR(block_w); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_shape[0], block_w, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input->shape_[kNHWC_N], block_shape[0] * block_w, NNACL_ERR); + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N] * (block_shape[0] * block_w); + outputs[0]->shape_[kNHWC_H] = (input->shape_[kNHWC_H] + paddings[0] + paddings[1]) / block_shape[0]; + outputs[0]->shape_[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C]; + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(SpaceToBatch, PrimType_SpaceToBatch, SpaceToBatchInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d07d8ebf942a344291d8eaf2142dd23594edf19f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2415942a8509a52f2a748e7610f68035cf491928 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/space_to_batch_nd_infer.h" +#include <limits.h> +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SpaceSetOutputShapeFromParam(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter) { + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + const SpaceToBatchParameter *param = (const SpaceToBatchParameter *)parameter; + const int *block_shape = param->block_sizes_; + int block_shape_size = param->m_; + const int *padding = param->paddings_; + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (block_shape_size == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + if (input->shape_[kNHWC_N] == 0 || block_shape[0] * block_w > INT_MAX / input->shape_[kNHWC_N]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N] * block_shape[0] * block_w; + if (padding[0] + padding[1] > INT_MAX - input->shape_[kNHWC_H]) { + return NNACL_ERR; + } + if (block_shape[0] == 0 || block_w == 0) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_H] = (input->shape_[kNHWC_H] + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > INT_MAX - input->shape_[kNHWC_W]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + if (input->shape_size_ > 3) { + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C]; + } + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +int SpaceSetOutputShapeFromInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + const TensorC *input = inputs[0]; + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + if (NNACLGetElementNum(inputs[2]) != 4) { + return NNACL_ERR; + } + int *block_shape = (int *)(inputs[1]->data_); + int *padding = 
(int *)(inputs[2]->data_); + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (NNACLGetElementNum(inputs[1]) == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input->shape_size_; + if (input->shape_[kNHWC_N] == 0 || block_shape[0] * block_w > INT_MAX / input->shape_[kNHWC_N]) { + return NNACL_ERR; + } + output_shape[kNHWC_N] = input->shape_[kNHWC_N] * block_shape[0] * block_w; + if (padding[0] + padding[1] > INT_MAX - input->shape_[kNHWC_H]) { + return NNACL_ERR; + } + if (block_shape[0] == 0 || block_w == 0) { + return NNACL_ERR; + } + output_shape[kNHWC_H] = (input->shape_[kNHWC_H] + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > INT_MAX - input->shape_[kNHWC_W]) { + return NNACL_ERR; + } + output_shape[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + if (input->shape_size_ > 3) { + output_shape[kNHWC_C] = input->shape_[kNHWC_C]; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_ERR; + } + outputs[0]->data_type_ = input->data_type_; + outputs[0]->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 1) { + int ret = SpaceSetOutputShapeFromParam(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + if (inputs_size == 3) { + if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int ret = SpaceSetOutputShapeFromInput(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +REG_INFER(SpaceToBatchND, PrimType_SpaceToBatchND, SpaceToBatchNdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e1c07688fc9015d2253fc1266f5e273d45f65957 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..09b5aecfe4be1c28f214840061c802ec4d1f03a3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/space_to_depth_infer.h" +#include <limits.h> +#include "nnacl_c/infer/infer_register.h" + +int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + SpaceToDepthParameter *param = (SpaceToDepthParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + + int32_t block_size = param->block_size_; + if (block_size == 0) { + return NNACL_ERR; + } + if (input->shape_[kNHWC_H] % block_size != 0 || input->shape_[kNHWC_H] == 0 || + input->shape_[kNHWC_W] % block_size != 0 || input->shape_[kNHWC_W] == 0) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N]; + outputs[0]->shape_[kNHWC_H] = input->shape_[kNHWC_H] / block_size; + outputs[0]->shape_[kNHWC_W] = input->shape_[kNHWC_W] / block_size; + if (input->shape_[kNHWC_C] == 0 || block_size * block_size > INT_MAX / input->shape_[kNHWC_C]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C] * (block_size * block_size); + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(SpaceToDepth, PrimType_SpaceToDepth, SpaceToDepthInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..88fe7f208d145d9931e0ee3159061a9613d0b548 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/space_to_depth_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c1369dd9b397430be1054f1aad0a3111245e0ea2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/sparse_fill_empty_rows_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseFillEmptyRowsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, C4NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input0); + + const TensorC *input1 = inputs[1]; + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output1, input1); + + TensorC *output2 = outputs[C2NUM]; + SetDataTypeFormat(output2, input0); + output2->data_type_ = kNumberTypeBool; + + if (outputs_size == C4NUM) { + TensorC *output3 = outputs[C3NUM]; + SetDataTypeFormat(output3, input0); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return NNACL_INFER_INVALID; +} + +REG_INFER(SparseFillEmptyRows, PrimType_SparseFillEmptyRows, SparseFillEmptyRowsInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e6ce7882942759e574db3fbc564d02cd54f5d777 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H +#define MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseFillEmptyRowsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a00be2badc10310c64f6b4535c45058252dc7627 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c @@ -0,0 +1,53 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/sparse_reshape_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C2NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_indices_tensor = inputs[0]; + TensorC *out_indices_tensor = outputs[0]; + SetDataTypeFormat(out_indices_tensor, in_indices_tensor); + + const TensorC *in_out_shape_tensor = inputs[C2NUM]; + TensorC *out_shape_tensor = outputs[C1NUM]; + SetDataTypeFormat(out_shape_tensor, in_out_shape_tensor); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeArray(out_shape_tensor, in_out_shape_tensor->shape_, in_out_shape_tensor->shape_size_); + + int out_indices_shape[MAX_SHAPE_SIZE] = {0}; + out_indices_shape[0] = in_indices_tensor->shape_[0]; + size_t out_indices_shape_size = 1; + + for (int i = 0; i < in_out_shape_tensor->shape_size_; ++i) { + out_indices_shape[i + 1] = in_out_shape_tensor->shape_[i]; + out_indices_shape_size++; + } + SetShapeArray(out_indices_tensor, out_indices_shape, out_indices_shape_size); + return NNACL_OK; +} + +REG_INFER(SparseReshape, PrimType_SparseReshape, SparseReshapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e594ffffeacc724b840d7a103f2f67f63c988fa3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H +#define MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cc0263a020b13742c1e0e1f8c09d8a157fc682e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sparse_segment_sum_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + return NNACL_OK; +} + +REG_INFER(SparseSegmentSum, PrimType_SparseSegmentSum, SparseSegmentSumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4589c724aeb60cf2c2a58571493d65cf3b16fd26 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H +#define MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..84fb7d9e2b3c8b2aa960d003f48655b3233b7869 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h" +#include "nnacl_c/fp32_grad/softmax_grad.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + SoftmaxCrossEntropyParameter *param = (SoftmaxCrossEntropyParameter *)parameter; + if (param->is_grad_ != 0) { + SetShapeTensor(out, in0); + SetDataTypeFormat(out, in0); + } else { + out->shape_size_ = 1; + out->shape_[0] = 1; + SetDataTypeFormat(out, in0); + } + + return NNACL_OK; +} + +REG_INFER(SparseSoftmaxCrossEntropyWithLogits, PrimType_SparseSoftmaxCrossEntropyWithLogits, + SparseSoftmaxCrossEntropyWithLogitsInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..396b50e0ff025f5f7877d9b853677ee53ac5dc02 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ +#define MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1d017c7352ced6b2a1893cc99c0b1040ec194104 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sparse_to_dense_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *output = outputs[0]; + if (inputs_size < 3) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *input1 = inputs[1]; + SetDataTypeFormat(output, input1); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int *input1_data = (int *)(input1->data_); + int data_num = NNACLGetElementNum(input1); + if (input1_data == 0 || data_num > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < data_num; i++) { + ShapePush(output_shape, &output_shape_size, input1_data[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(SparseToDense, PrimType_SparseToDense, SparseToDenseInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9a521b55e01efd0eda65b4aff128477d76cf8576 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPARSE_TO_DENSE_INFER_H +#define MINDSPORE_NNACL_SPARSE_TO_DENSE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_TO_DENSE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..471af16471234ce6a3bc8696659ca76e9bbbe689 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/splice_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != DIMENSION_3D) { + return NNACL_INPUT_TENSOR_ERROR; + } + SpliceParameter *param = (SpliceParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + int out_dim = param->output_dim_; + ShapeSet(output->shape_, &output->shape_size_, input->shape_, input->shape_size_); + + if (param->context_dim_ == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + if (param->forward_indexes_dim_ % param->context_dim_ != 0) { + return NNACL_PARAM_INVALID; + } + int out_size = param->forward_indexes_dim_ / param->context_dim_; + output->shape_[DIMENSION_1D] = out_size; + output->shape_[DIMENSION_2D] = out_dim; + return NNACL_OK; +} + +REG_INFER(Splice, PrimType_Splice, SpliceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..312b1ee738f764f8cdaebef8fa52de80cb3b9dc4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ +#define MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/splice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1c4b79b6aa0230f33479a883f2fa921793907384 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/split_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UpdateSplitSize(const TensorC *const *inputs, size_t inputs_size, SplitParameter *param) { + // get split size from the second input. 
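+ // e.g. a second input holding {2, 4} for a length-6 split axis gives split_sizes_ = {2, 4}; + // when no sizes are provided, near-equal chunks of UP_DIV(dim, num_split_) are computed below.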
+ if (inputs_size == DIMENSION_2D && inputs[SECOND_INPUT]->data_ != NULL) { + if (inputs[SECOND_INPUT]->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + int split_count = 1; + for (size_t i = 0; i < inputs[SECOND_INPUT]->shape_size_; i++) { + split_count *= inputs[SECOND_INPUT]->shape_[i]; + } + param->split_count_ = split_count; + for (int i = 0; i < split_count; i++) { + param->split_sizes_[i] = ((int *)(inputs[SECOND_INPUT]->data_))[i]; + } + } + if (param->split_count_ == 0) { + const TensorC *input = inputs[0]; + int32_t split_chunk_size = UP_DIV(input->shape_[param->split_dim_], param->num_split_); + for (int i = 0; i < param->num_split_; ++i) { + if (i != param->num_split_ - 1) { + param->split_sizes_[i] = split_chunk_size; + } else { + param->split_sizes_[i] = input->shape_[param->split_dim_] - split_chunk_size * i; + } + } + } + return NNACL_OK; +} + +int SetSplitOutputShape(const TensorC *input, TensorC **outputs, SplitParameter *param) { + for (int i = 0; i < param->num_split_; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int split_dim_i = input->shape_[param->split_dim_]; + if (i == param->num_split_ - 1 && param->split_sizes_[i] == -1) { + if (param->num_split_ - 1 < 0) { + return NNACL_ERR; + } + for (int j = 0; j < param->num_split_ - 1; ++j) { + split_dim_i -= param->split_sizes_[j]; + } + param->split_sizes_[i] = split_dim_i; + } else { + split_dim_i = param->split_sizes_[i]; + } + NNACL_CHECK_TRUE_RET(split_dim_i >= 0 && split_dim_i <= input->shape_[param->split_dim_], NNACL_ERR); + output_shape[param->split_dim_] = split_dim_i; + SetShapeArray(outputs[i], output_shape, output_shape_size); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + SplitParameter *param = (SplitParameter *)parameter; + + int num_split = param->num_split_ == 0 ? (int)(outputs_size) : param->num_split_; + if (num_split == 0) { + return NNACL_ERR; + } + param->num_split_ = num_split; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int split_dim = param->split_dim_ < 0 ? 
((int)(input->shape_size_)) + param->split_dim_ : param->split_dim_; + if (split_dim >= (int)(input->shape_size_) || split_dim < 0) { + return NNACL_ERR; + } + param->split_dim_ = split_dim; + if ((int)(outputs_size) != num_split) { + return NNACL_ERR; + } + + int ret = UpdateSplitSize(inputs, inputs_size, param); + if (ret != NNACL_OK) { + return ret; + } + ret = SetSplitOutputShape(input, outputs, param); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +REG_INFER(Split, PrimType_Split, SplitInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..acb8f5e1e453de750ad6d59018d1b42fb9cbcb70 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_INFER_H +#define MINDSPORE_NNACL_SPLIT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..39c1463e8c5c8a59e088bc4691293286ed3aeaa9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/split_reduce_concat_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/split_parameter.h" + +int SplitReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + NNACL_CHECK_TRUE_RET(inputs_size == outputs_size, NNACL_INPUT_TENSOR_ERROR); + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + out_tensor->format_ = in_tensor->format_; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + out_tensor->shape_[i] = in_tensor->shape_[i]; + } + SplitParameter *param = (SplitParameter *)parameter; + out_tensor->shape_[param->split_dim_] = param->num_split_; + out_tensor->shape_size_ = in_tensor->shape_size_; + return NNACL_OK; +} + +REG_INFER(SplitReduceConcatFusion, PrimType_Inner_SplitReduceConcatFusion, SplitReduceConcatFusionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..80cfc65c68c383ffc7ff4f70f5d3ed570e4f8445 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H +#define MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0588fb195ac160ed5335284ec28f0db5ca8cd5c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/split_with_over_lap_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; + + int split_dim = param->split_dim_; + int number_split = param->num_split_; + if (outputs_size != (size_t)number_split) { + return NNACL_ERR; + } + + int ratio[SPLIT_MAX_SLICE_NUM]; + int extend_top[SPLIT_MAX_SLICE_NUM]; + int extend_bottom[SPLIT_MAX_SLICE_NUM]; + for (int i = 0; i < number_split; ++i) { + ratio[i] = param->ratio_[i]; + extend_top[i] = param->extend_top_[i]; + extend_bottom[i] = param->extend_bottom_[i]; + } + + const int *input_shape = input->shape_; + int split_dim_size = input_shape[split_dim]; + int total_block_count = 0; + for (int i = 0; i < number_split; i++) { + total_block_count += ratio[i]; + } + + int borders[MAX_SHAPE_SIZE]; + borders[0] = 0; + int visited_block = 0; + for (int i = 0; i < number_split - 1; i++) { + visited_block += ratio[i]; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(split_dim_size, visited_block) || total_block_count == 0, NNACL_ERR); + int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); + borders[i + 1] = cur_border; + } + borders[number_split] = split_dim_size; + + for (int i = 0; i < number_split; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + for (int dim = 0; dim < input->shape_size_; dim++) { + if (dim == split_dim) { + int splited_size = borders[i + 1] - borders[i]; + splited_size += (extend_top[i] + extend_bottom[i]); + output_shape[dim] = splited_size; + } else { + output_shape[dim] = input_shape[dim]; + } + } + SetShapeArray(outputs[i], output_shape, input->shape_size_); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..96b9ef40f8d0cdd04359466c14665d87053ea379 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H +#define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b663436d156508a03b9d3d9f92ec0e0714cb8ded --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/squeeze_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = + CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, kInputSize1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SqueezeParameter *param = (SqueezeParameter *)parameter; + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + + if (inputs_size == kInputSize1) { + NNACL_CHECK_TRUE_RET(inputs[1]->data_type_ == kNumberTypeInt32 || inputs[1]->data_type_ == kNumberTypeInt, + NNACL_PARAM_INVALID); + int *axis_data = (int *)(inputs[1]->data_); + NNACL_CHECK_TRUE_RET(axis_data != NULL, NNACL_PARAM_INVALID); + param->axis_size_ = NNACLGetElementNum(inputs[1]); + for (size_t i = 0; i < param->axis_size_; i++) { + param->axis_[i] = *(axis_data + i); + } + } + if (param->axis_size_ > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + + for (size_t i = 0; i < param->axis_size_; i++) { + param->axis_[i] = param->axis_[i] >= 0 ? 
param->axis_[i] : param->axis_[i] + (int)input->shape_size_; + } + + if (param->axis_size_ == 0) { + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != 1) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } else { + size_t axisIdx = 0; + for (size_t i = 0; i < input->shape_size_; i++) { + if (axisIdx < param->axis_size_ && param->axis_[axisIdx] == (int)(i)) { + if (input->shape_[i] != 1) return NNACL_PARAM_INVALID; + axisIdx++; + continue; + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } + SetShapeArray(outputs[0], out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Squeeze, PrimType_Squeeze, SqueezeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ef2a773af50adb475b8d6d5dd3434cf8f5011a19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SQUEEZE_INFER_H +#define MINDSPORE_NNACL_SQUEEZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/squeeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SQUEEZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d87725753adc643b1caf215cd0543667555f8d7a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/stack_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (outputs_size != 1) { + return NNACL_PARAM_INVALID; + } + if (inputs_size < 1) { + return NNACL_PARAM_INVALID; + } + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + StackParameter *param = (StackParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int axis = param->axis_ < 0 ? (int)(param->axis_) + (int)(input->shape_size_) + 1 : param->axis_; + if (axis < 0 || axis > (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ != input->shape_size_) { + return NNACL_PARAM_INVALID; + } + for (size_t j = 0; j < input->shape_size_; ++j) { + if (inputs[i]->shape_[j] != input->shape_[j]) { + return NNACL_PARAM_INVALID; + } + } + if (inputs[i]->data_type_ != input->data_type_) { + return NNACL_PARAM_INVALID; + } + } + int insert_ret = ShapeInsert(output_shape, &output_shape_size, axis, inputs_size); + if (insert_ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Stack, PrimType_Stack, StackInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..aec77cce12204325092ec14623ff8b720be0629d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_STACK_INFER_H +#define MINDSPORE_NNACL_STACK_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/stack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STACK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..17e286250ad90648ab25ea0b9f783c282f5b85cc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/strided_slice_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +bool StridedSliceCheckInputs(const TensorC *const *inputs, size_t inputs_size) { + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->data_ == NULL) { + return false; + } + } + if (NNACLGetElementNum(inputs[2]) > MAX_SHAPE_SIZE) { + return false; + } + if (NNACLGetElementNum(inputs[2]) != NNACLGetElementNum(inputs[3]) && + NNACLGetElementNum(inputs[2]) != NNACLGetElementNum(inputs[4])) { + return false; + } + return true; // note: the original code is ndim_ <= in_shape_size +} + +void ApplyBeginEndEllipsisMask(size_t ndim, int *begins, const uint32_t *const begins_mask, int *ends, + const uint32_t *const ends_mask, const uint32_t *const ellipsis_mask, + const int *const in_shape) { + for (size_t i = 0; i < ndim; i++) { + if (begins_mask[i]) { + begins[i] = 0; + } + if (ends_mask[i]) { + ends[i] = in_shape[i]; + } + } + for (size_t i = 0; i < ndim; i++) { + if (ellipsis_mask[i]) { + begins[i] = 0; + ends[i] = in_shape[i]; + break; + } + } +} + +int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + bool inferflag = InferFlag(inputs, inputs_size); + + int in_shape_[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + if (inferflag) { + ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_); + } + int begins_[MAX_SHAPE_SIZE] = {0}; + size_t begins_size = 0; + int ends_[MAX_SHAPE_SIZE] = {0}; + size_t ends_size = 0; + int strides_[MAX_SHAPE_SIZE] = {0}; + size_t strides_size = 0; + + if (!StridedSliceCheckInputs(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // input order: dy, shapex, begins, ends, strides. 
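+  // dy (inputs[0]) is the gradient of the sliced output; shapex (inputs[1]) carries the shape of + // the original forward input and is copied into the output shape at the end of this function.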
+ const TensorC *begin_tensor = inputs[2]; + int *begin_data = (int *)(begin_tensor->data_); + int *end_data = (int *)(inputs[3]->data_); + int *stride_data = (int *)(inputs[4]->data_); + + size_t ndim_ = (size_t)NNACLGetElementNum(begin_tensor); + if (ndim_ > MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < ndim_; ++i) { + ShapePush(begins_, &begins_size, begin_data[i]); + ShapePush(ends_, &ends_size, end_data[i]); + ShapePush(strides_, &strides_size, stride_data[i]); + } + + // set all mask to original input shape + uint32_t begins_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ends_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ellipsis_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t new_axis_mask_[MAX_SHAPE_SIZE] = {0}; + + StridedSliceParameter *param = (StridedSliceParameter *)parameter; + for (size_t i = 0; i < ndim_; i++) { + begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + } + param->num_axes_ = (int)(in_shape_size); + param->in_shape_length_ = (int)(in_shape_size); + for (size_t i = 0; i < ndim_; ++i) { + param->begins_[i] = begins_[i]; + param->ends_[i] = ends_[i]; + param->strides_[i] = strides_[i]; + } + ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_); + // ApplyNewAxisMask; + for (size_t i = 0; i < ndim_; i++) { + if (new_axis_mask_[i]) { + ndim_ += 1; + int ret = ShapeInsert(in_shape_, &in_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + begins_[i] = 0; + ends_[i] = 1; + strides_[i] = 1; + + ShapePush(begins_, &begins_size, 0); + ShapePush(ends_, &ends_size, in_shape_[ndim_ - 1]); + ShapePush(strides_, &strides_size, 1); + + begins_mask_[i] = false; + ends_mask_[i] = false; + ellipsis_mask_[i] = false; + } + } + ApplyBeginEndEllipsisMask(ndim_, begins_, begins_mask_, ends_, ends_mask_, ellipsis_mask_, in_shape_); + if (!inferflag) { + return NNACL_OK; + } + int output_size = inputs[1]->shape_[0]; + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (inputs[1]->data_ == NULL) { + return NNACL_ERR; + } + + if (output_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < output_size; i++) { + ShapePush(output_shape, &output_shape_size, ((int *)(inputs[1]->data_))[i]); + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(StridedSliceGrad, PrimType_StridedSliceGrad, StridedSliceGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..caba55ba822032d13bf40e712d9e1b7ea5265dd3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H +#define MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..89cbf0eb87cf413c281cc1b71e44000941c7fb3e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c @@ -0,0 +1,483 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/strided_slice_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c_utils.h" + +const size_t kStridedSliceOutputNum = 1; +const size_t kStridedSliceInputNum = 1; +const size_t kStridedSliceMultiInputNumMin = 3; +const size_t kStridedSliceMultiInputNumMax = 5; + +typedef struct StridedSliceTransferBuffer { + int ndim_; + + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int begins_mask_[MAX_SHAPE_SIZE]; + int ends_mask_[MAX_SHAPE_SIZE]; + int ellipsis_mask_[MAX_SHAPE_SIZE]; + int new_axis_mask_[MAX_SHAPE_SIZE]; + int shrink_axis_mask_[MAX_SHAPE_SIZE]; + + size_t begins_size_; + size_t ends_size_; + size_t strides_size_; + size_t ellipsis_mask_size_; + size_t new_axis_mask_size_; + size_t shrink_axis_mask_size_; +} StridedSliceTransferBuffer; + +bool CheckInputs(const TensorC *const *inputs, size_t inputs_size) { + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->data_ == NULL) { + return false; + } + } + return true; +} + +int HandleAxesCheckNull(const TensorC *input_tensor, const TensorC *begin_tensor, int *begin_data, + const TensorC *end_tensor, int *end_data) { + if (input_tensor == NULL || begin_tensor == NULL || end_tensor == NULL || begin_data == NULL || end_data == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int HandleAxesInputNotExist(const TensorC *const *inputs, struct StridedSliceTransferBuffer *transfer_buffer) { + const TensorC *begin_tensor = inputs[1]; + const TensorC *end_tensor = inputs[2]; + const TensorC *stride_tensor = inputs[3]; + int ret = GetInt32DataFromTensor(begin_tensor, transfer_buffer->begins_, &transfer_buffer->begins_size_); + if (ret != NNACL_OK) { + return ret; + } + transfer_buffer->ndim_ = NNACLGetElementNum(begin_tensor); + ret = GetInt32DataFromTensor(end_tensor, transfer_buffer->ends_, &transfer_buffer->ends_size_); + if (ret != 
NNACL_OK) { + return ret; + } + ret = GetInt32DataFromTensor(stride_tensor, transfer_buffer->strides_, &transfer_buffer->strides_size_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int GenerateAxes(const TensorC *axes_tensor, int *axes, int num, int ndim) { + int *axes_data = NULL; + if (NNACLGetElementNum(axes_tensor) != 0) { + if (NNACLGetElementNum(axes_tensor) != num) { + return NNACL_ERR; + } + axes_data = (int *)(axes_tensor->data_); + if (axes_data == NULL) { + return NNACL_NULL_PTR; + } + } + if (axes_data == NULL) { + for (int i = 0; i < num; ++i) { + axes[i] = i; + } + } else { + for (int i = 0; i < num; i++) { + axes[i] = axes_data[i]; + } + for (int i = 0; i < num; ++i) { + if (axes[i] < 0) { + axes[i] += ndim; + } + } + } + return NNACL_OK; +} + +int HandleAxesInputExist(const TensorC *const *inputs, int *ndim, int *in_shape, int *begins, int *strides, int *ends) { + const TensorC *input_tensor = inputs[0]; + const TensorC *begin_tensor = inputs[1]; + int begin_data[MAX_SHAPE_SIZE]; + size_t begin_data_size; + int ret = GetInt32DataFromTensor(begin_tensor, begin_data, &begin_data_size); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *end_tensor = inputs[2]; + int end_data[MAX_SHAPE_SIZE]; + size_t end_data_size; + ret = GetInt32DataFromTensor(end_tensor, end_data, &end_data_size); + if (ret != NNACL_OK) { + return ret; + } + + int handle_check_ret = HandleAxesCheckNull(input_tensor, begin_tensor, begin_data, end_tensor, end_data); + if (handle_check_ret != NNACL_OK) { + return handle_check_ret; + } + + // when input contains axes, begins, ends, strides will be expand to the same length as input rank + *ndim = (int)(input_tensor->shape_size_); + int begin_ndim = NNACLGetElementNum(begin_tensor); + + int *stride_data = NULL; + const TensorC *stride_tensor = inputs[4]; + int stride_data_num = NNACLGetElementNum(stride_tensor); + if (stride_data_num != 0) { + NNACL_CHECK_TRUE_RET(stride_data_num == begin_ndim, NNACL_ERR); + stride_data = (int *)(stride_tensor->data_); + } + + const TensorC *axes_tensor = inputs[3]; + int axes[MAX_SHAPE_SIZE] = {0}; + ret = GenerateAxes(axes_tensor, axes, begin_ndim, *ndim); + if (ret != NNACL_OK) { + return ret; + } + + if (*ndim > MAX_SHAPE_SIZE || *ndim < 0) { + return NNACL_ERR; + } + for (int i = 0; i < *ndim; i++) { + in_shape[i] = 0; + begins[i] = 0; + strides[i] = 0; + } + for (int i = 0; i < *ndim; ++i) { + in_shape[i] = input_tensor->shape_[i]; + } + for (int i = 0; i < *ndim; ++i) { + int axes_it = 0; + if (begin_ndim > MAX_SHAPE_SIZE || begin_ndim < 0) { + return NNACL_ERR; + } + for (int j = 0; j < begin_ndim; j++) { + if (axes[j] == i) { + axes_it = j; + break; + } else { + axes_it++; + } + } + if (axes_it != begin_ndim) { + int axis = axes_it; + if (begin_data[axis] > input_tensor->shape_[i] - 1) { + begins[i] = begin_data[axis]; + } else { + begins[i] = imax(imin(begin_data[axis], input_tensor->shape_[i] - 1), -input_tensor->shape_[i]); + } + // ends exceed limit will be set to limit + ends[i] = imax(imin(end_data[axis], input_tensor->shape_[i]), -input_tensor->shape_[i] - 1); + if (stride_data == NULL) { + return NNACL_ERR; + } + strides[i] = stride_data[axis]; + } else { + begins[i] = 0; + ends[i] = input_tensor->shape_[i]; + strides[i] = 1; + } + } + return NNACL_OK; +} + +int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (outputs_size != kStridedSliceOutputNum) { + return NNACL_PARAM_INVALID; + } 
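+  // With a single input, begin/end/stride are taken from the parameter; with 3 to 5 inputs they + // arrive as extra tensors (begins, ends, [axes,] strides), whose element types are checked next.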
+ if (inputs_size != kStridedSliceInputNum && + !(inputs_size <= kStridedSliceMultiInputNumMax && inputs_size >= kStridedSliceMultiInputNumMin)) { + return NNACL_PARAM_INVALID; + } + if (parameter == NULL || outputs[0] == NULL || inputs[0] == NULL) { + return NNACL_NULL_PTR; + } + if (inputs_size >= kStridedSliceMultiInputNumMin) { + bool begins_type_ok = + (inputs[C1NUM]->data_type_ == kNumberTypeInt32) || (inputs[C1NUM]->data_type_ == kNumberTypeInt64); + bool ends_type_ok = + (inputs[C2NUM]->data_type_ == kNumberTypeInt32) || (inputs[C2NUM]->data_type_ == kNumberTypeInt64); + if (!(begins_type_ok && ends_type_ok)) { + return NNACL_PARAM_INVALID; + } + } + return NNACL_OK; +} + +void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, const StridedSliceParameter *param) { + for (unsigned i = 0; i < (unsigned)(size_t)(transfer_buffer->ndim_); i++) { + transfer_buffer->begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + transfer_buffer->ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + transfer_buffer->ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + transfer_buffer->new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + transfer_buffer->shrink_axis_mask_[i] = (unsigned)(param->shrinkAxisMask_) & (1 << i); + } +} + +int ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape, + size_t *out_shape_size) { + for (size_t i = 0; i < transfer_buffer->new_axis_mask_size_; i++) { + if (transfer_buffer->new_axis_mask_[i]) { + if (*out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ret = ShapeInsert(in_shape, out_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = 1; + transfer_buffer->strides_[i] = 1; + + ShapePush(transfer_buffer->begins_, &transfer_buffer->begins_size_, 0); + ShapePush(transfer_buffer->ends_, &transfer_buffer->ends_size_, in_shape[(size_t)(transfer_buffer->ndim_) - 1]); + ShapePush(transfer_buffer->strides_, &transfer_buffer->strides_size_, 1); + + transfer_buffer->begins_mask_[i] = false; + transfer_buffer->ends_mask_[i] = false; + transfer_buffer->ellipsis_mask_[i] = false; + transfer_buffer->shrink_axis_mask_[i] = false; + } + } + return NNACL_OK; +} + +void ApplyBeginMask(StridedSliceTransferBuffer *transfer_buffer) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->begins_mask_[i]) { + transfer_buffer->begins_[i] = transfer_buffer->strides_[i] > 0 ? 0 : -1; + } + } +} + +int ApplyEndMask(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t in_shape_size) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->ends_mask_[i]) { + if ((size_t)i >= in_shape_size) { + return NNACL_ERR; + } + transfer_buffer->ends_[i] = transfer_buffer->strides_[i] > 0 ? 
in_shape[i] : -1 - in_shape[i]; + } + } + return NNACL_OK; +} + +int ApplyEllipsisMask(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t in_shape_size) { + for (size_t i = 0; i < transfer_buffer->ellipsis_mask_size_; i++) { + if (transfer_buffer->ellipsis_mask_[i]) { + if (i >= in_shape_size) { + return NNACL_ERR; + } + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = in_shape[i]; + break; + } + } + return NNACL_OK; +} + +int TransIndexToPositive(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t max_shape_size, + size_t in_shape_size) { + for (size_t i = 0; i < transfer_buffer->begins_size_; i++) { + if (i >= max_shape_size) { + return NNACL_ERR; + } + if (transfer_buffer->begins_[i] < 0) { + transfer_buffer->begins_[i] += in_shape[i]; + } + if (transfer_buffer->ends_[i] < 0) { + transfer_buffer->ends_[i] += in_shape[i]; + } + if (i < in_shape_size) { + if (transfer_buffer->begins_[i] < 0 || transfer_buffer->begins_[i] > in_shape[i]) { + return NNACL_ERR; + } + if ((transfer_buffer->ends_[i] < 0 && transfer_buffer->ends_[i] != -1) || + transfer_buffer->ends_[i] > in_shape[i]) { + return NNACL_ERR; + } + } + } + return NNACL_OK; +} + +void ApplyShrinkMask(StridedSliceTransferBuffer *transfer_buffer, int *output_shape, size_t *output_shape_size) { + int old_out_shape[MAX_SHAPE_SIZE] = {0}; + size_t old_out_shape_size = 0; + ShapeSet(old_out_shape, &old_out_shape_size, output_shape, *output_shape_size); + *output_shape_size = 0; + for (size_t i = 0; i < transfer_buffer->shrink_axis_mask_size_; i++) { + if (transfer_buffer->shrink_axis_mask_[i]) { + transfer_buffer->ends_[i] = transfer_buffer->begins_[i] + 1; + transfer_buffer->strides_[i] = 1; + } else { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } + } + for (size_t i = transfer_buffer->shrink_axis_mask_size_; i < old_out_shape_size; i++) { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } +} + +int TransferBuffer2Param(const StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, + const int *in_shape, size_t in_shape_size) { + if (transfer_buffer->ndim_ >= (int)(in_shape_size) || param->in_shape_length_ >= (int)(in_shape_size)) { + return NNACL_ERR; + } + for (int i = 0; i < transfer_buffer->ndim_; i++) { + param->begins_[i] = transfer_buffer->begins_[i]; + param->ends_[i] = transfer_buffer->ends_[i]; + param->in_shape_[i] = in_shape[i]; + param->strides_[i] = transfer_buffer->strides_[i]; + } + + for (int i = transfer_buffer->ndim_; i < param->in_shape_length_; i++) { + param->begins_[i] = 0; + param->ends_[i] = in_shape[i]; + param->in_shape_[i] = in_shape[i]; + param->strides_[i] = 1; + } + return NNACL_OK; +} + +void InitStridedSliceTransferBuffer(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->begins_size_ = 0; + transfer_buffer->ends_size_ = 0; + transfer_buffer->strides_size_ = 0; + transfer_buffer->ellipsis_mask_size_ = 0; + transfer_buffer->new_axis_mask_size_ = 0; + transfer_buffer->shrink_axis_mask_size_ = 0; +} + +void SetMaskSize(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->ellipsis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->new_axis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->shrink_axis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->begins_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->ends_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->strides_size_ = 
(size_t)(transfer_buffer->ndim_); +} + +// note: begin, end, and stride lengths are equal, but may be less than the rank of the input +int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = StrideSlicePreCheck(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], inputs[0]); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + + StridedSliceTransferBuffer transfer_buffer; + InitStridedSliceTransferBuffer(&transfer_buffer); + + StridedSliceParameter *param = (StridedSliceParameter *)parameter; + + transfer_buffer.ndim_ = 0; + if (inputs_size == kStridedSliceInputNum) { + transfer_buffer.ndim_ = (int)(in_shape_size); + if (transfer_buffer.ndim_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < transfer_buffer.ndim_; i++) { + ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, param->begins_[i]); + ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, param->ends_[i]); + ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, param->strides_[i]); + } + } + if (!CheckInputs(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 4) { + int ret = HandleAxesInputNotExist(inputs, &transfer_buffer); + if (ret != NNACL_OK) { + return ret; + } + } + + if (inputs_size == 5) { + int ret = HandleAxesInputExist(inputs, &transfer_buffer.ndim_, in_shape, transfer_buffer.begins_, + transfer_buffer.strides_, transfer_buffer.ends_); + if (ret != NNACL_OK) { + return ret; + } + } + + // set all mask to original input shape + SetMaskSize(&transfer_buffer); + Bit2Vector(&transfer_buffer, param); + int ret = ApplyNewAxisMask(&transfer_buffer, param, in_shape, &in_shape_size); + if (ret != NNACL_OK) { + return ret; + } + + // update parameter with new input shape + param->num_axes_ = (int)(in_shape_size); + param->in_shape_length_ = (int)(in_shape_size); + + ApplyBeginMask(&transfer_buffer); + ret = ApplyEndMask(&transfer_buffer, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + ret = ApplyEllipsisMask(&transfer_buffer, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, in_shape, in_shape_size); + ret = TransIndexToPositive(&transfer_buffer, in_shape, MAX_SHAPE_SIZE, input->shape_size_); + if (ret != NNACL_OK) { + return ret; + } + for (int i = 0; i < transfer_buffer.ndim_; i++) { + if (transfer_buffer.strides_[i] == 0 || in_shape[i] < transfer_buffer.ends_[i]) { + return NNACL_ERR; + } + output_shape[i] = (transfer_buffer.ends_[i] - transfer_buffer.begins_[i] + transfer_buffer.strides_[i] + + (transfer_buffer.strides_[i] < 0 ? 1 : -1)) / + transfer_buffer.strides_[i]; + output_shape[i] = output_shape[i] > 0 ?
output_shape[i] : 0; + } + ApplyShrinkMask(&transfer_buffer, output_shape, &output_shape_size); + SetShapeArray(outputs[0], output_shape, output_shape_size); + ret = TransferBuffer2Param(&transfer_buffer, param, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +REG_INFER(StridedSlice, PrimType_StridedSlice, StridedSliceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..492f5b2abd9cf6ed120934a75c53748fbb1c5ca0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_STRIDED_SLICE_INFER_H +#define MINDSPORE_NNACL_STRIDED_SLICE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STRIDED_SLICE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..889d2daf0fc81b3e3ff21d837fa5bed0f5066612 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/custom_extract_features_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomExtractFeaturesInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + output0->data_type_ = kNumberTypeInt32; + output0->format_ = input->format_; + output1->data_type_ = kNumberTypeFloat32; + output1->format_ = input->format_; + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int string_num = *((const int32_t *)(input->data_)); + + int res = (string_num == 0 ? 1 : string_num); + output0->shape_size_ = 1; + output0->shape_[0] = res; + output1->shape_size_ = 1; + output1->shape_[0] = res; + return NNACL_OK; +} + +REG_INFER(CustomExtractFeatures, PrimType_CustomExtractFeatures, CustomExtractFeaturesInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..861492291a79c5ac4136042cb71ddf7a811b920f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomExtractFeaturesInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4cbb57eeab2ea59c5e796eef63d720d34399918b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/string/custom_normalize_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int CustomNormalizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(input) < 1) { + return NNACL_ERR; + } + if (input->data_type_ != kNumberTypeUInt32 && input->data_type_ != kObjectTypeString) { + return NNACL_ERR; + } + int string_num = *((const int32_t *)(input->data_)); // see also custom_extract_features_infer.c + output->shape_size_ = 1; + output->shape_[0] = (string_num == 0 ? 1 : string_num); + return NNACL_OK; +} + +REG_INFER(CustomNormalize, PrimType_CustomNormalize, CustomNormalizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..cd45bd5b32a2b0218ad2f5243ec1d124b689e34d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomNormalizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f0524c9685734f53a565509e2cbf0b08134a76bc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/string/custom_predict_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomPredictInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + CustomPredictParameter *param = (CustomPredictParameter *)parameter; + output0->shape_size_ = 1; + output0->shape_[0] = param->output_num; + output0->data_type_ = kNumberTypeInt32; + output0->format_ = input->format_; + output1->shape_size_ = 1; + output1->shape_[0] = param->output_num; + output1->data_type_ = kNumberTypeFloat32; + output1->format_ = input->format_; + return NNACL_OK; +} + +REG_INFER(CustomPredict, PrimType_CustomPredict, CustomPredictInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1d6410f6169bc69e7dcb2b45119297d1039292f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct CustomPredictParameter { + OpParameter op_parameter_; + int output_num; +} CustomPredictParameter; + +int CustomPredictInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c26d5c21463d796370349b69a8d16896303e1d58 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/string/hashtable_lookup_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int HashtableLoopupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *values = inputs[2]; + if (input == NULL || values == NULL) { + return NNACL_NULL_PTR; + } + + TensorC *output = outputs[0]; + TensorC *hits = outputs[1]; + + output->data_type_ = values->data_type_; + output->format_ = input->format_; + hits->shape_size_ = 1; + hits->shape_[0] = NNACLGetDimensionSize(input, 0); + hits->data_type_ = kNumberTypeUInt8; + hits->format_ = input->format_; + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(HashtableLookup, PrimType_HashtableLookup, HashtableLoopupInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9879e0e22c7f9d0bfe29aa50ce6e81f4dbeecb59 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int HashtableLoopupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0422e8001880adfc45d720b2206f4bb5596b4eba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/lsh_projection_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_hash = inputs[0]; + if (in_hash->shape_size_ != 2 || NNACLGetDimensionSize(in_hash, 1) > 32) { + return NNACL_ERR; + } + TensorC *out_tensor = outputs[0]; + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = Format_NHWC; + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + LshProjectionParameter *param = (LshProjectionParameter *)parameter; + switch (param->lsh_type_) { + case LshProjectionType_SPARSE: + ShapePush(out_shape, &out_shape_size, NNACLGetDimensionSize(in_hash, 0)); + break; + case LshProjectionType_DENSE: + ShapePush(out_shape, &out_shape_size, NNACLGetDimensionSize(in_hash, 0) * NNACLGetDimensionSize(in_hash, 1)); + break; + default: + return NNACL_ERR; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(LshProjection, PrimType_LshProjection, LshProjectionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..eb98f7e8b48d6924d84d325399cd265ae290c729 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/lsh_projection_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ebc454753742e080c363543bdd1f4ca3a8a7cb09 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/string/skip_gram_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SkipGramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(SkipGram, PrimType_SkipGram, SkipGramInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5fd2f1762bf008496fefacc8da5bfcd57250f8e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SkipGramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e191b2b476f098a56bd7b167d133eb07131e260a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c @@ -0,0 +1,111 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/tile_infer.h" +#include +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tile_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +void TileParamCaffe2Tflite(TileParameter *param, size_t out_shape_size) { + if (param->dims_size_ != 0) { + int multiples_size_tmp[5] = {0}; + NNACL_CHECK_TRUE_RET_VOID(out_shape_size <= 5); + for (size_t i = 0; i < out_shape_size; i++) { + multiples_size_tmp[i] = 1; + } + for (size_t i = 0; i < param->dims_size_; i++) { + if (i >= MAX_SHAPE_SIZE) { + return; + } + multiples_size_tmp[param->dims_[i]] = param->multiples_[i]; + } + for (size_t i = 0; i < 5; i++) { + param->multiples_[i] = multiples_size_tmp[i]; + } + } +} + +int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + TileParameter *param = (TileParameter *)parameter; + + size_t multiples_size = 0; + int input1_shape_size = inputs[1]->shape_size_; + if (input1_shape_size > (int)(input->shape_size_) || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + NNACL_CHECK_TRUE_RET(input1_shape_size <= MAX_SHAPE_SIZE, NNACL_ERR); + int data_num = NNACLGetElementNum(inputs[1]); + multiples_size = (size_t)(data_num); + if (inputs[1]->data_type_ != kNumberTypeInt && inputs[1]->data_type_ != kNumberTypeInt32) { + return NNACL_INPUT_TENSOR_ERROR; + } + int *input1_data = inputs[1]->data_; + if (input1_data == NULL) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(data_num <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int i = 0; i < data_num; i++) { + param->multiples_[i] = input1_data[i]; + } + + int *dims = param->dims_; + size_t dims_size = param->dims_size_; + if (dims_size == 0) { + int dim_num = NNACLGetElementNum(inputs[1]); + NNACL_CHECK_TRUE_RET(dim_num <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int dim = 0; dim < dim_num; ++dim) { + ShapePush(dims, &dims_size, dim); + } + param->dims_size_ = dims_size; + } + NNACL_CHECK_TRUE_RET(multiples_size == dims_size, NNACL_ERR); + for (size_t i = 0; i < input->shape_size_; ++i) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + for (size_t i = 0; i < dims_size; ++i) { + if (dims[i] >= MAX_SHAPE_SIZE || input->shape_[dims[i]] == 0) { + return NNACL_ERR; + } + if (input->shape_[dims[i]] != 0 && param->multiples_[i] > INT_MAX / input->shape_[dims[i]]) { + return NNACL_ERR; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input->shape_[dims[i]], (param->multiples_[i])), NNACL_ERR); + out_shape[dims[i]] = input->shape_[dims[i]] * (param->multiples_[i]); + } + // change caffe param format to tflite + TileParamCaffe2Tflite(param, out_shape_size); + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Tile, PrimType_TileFusion, TileInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ae1aedbc525199165981a232a78e3984d3e5d16b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h @@ -0,0 
+1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TILE_INFER_H +#define MINDSPORE_NNACL_TILE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/base/tile_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TILE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..fac77961b5839433d7debf116e2c8edb2d73ccd6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
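TileInferShape above copies the input shape and then multiplies each dimension named in dims_ by its multiple, rejecting products past INT_MAX; TileParamCaffe2Tflite afterwards normalizes the Caffe-style dims/multiples pair into the TFLite convention of one multiple per output dimension. The core shape rule as a stand-alone sketch (hypothetical names, not part of the patch):

#include <limits.h>
#include <stdio.h>

/* Sketch of the tile shape rule: out[d] = in[d] * multiple for every tiled
 * dim d; returns 0 on overflow or a bad dim, mirroring the NNACL_ERR path. */
static int tile_shape(const int *in, int rank, const int *dims, const int *mul, int n, int *out) {
  for (int i = 0; i < rank; ++i) out[i] = in[i];
  for (int i = 0; i < n; ++i) {
    int d = dims[i];
    if (d < 0 || d >= rank || in[d] == 0) return 0;
    if (mul[i] > INT_MAX / in[d]) return 0;
    out[d] = in[d] * mul[i];
  }
  return 1;
}

int main(void) {
  int in[] = {2, 3}, dims[] = {0, 1}, mul[] = {4, 2}, out[2];
  if (tile_shape(in, 2, dims, mul, 2, out)) printf("%d x %d\n", out[0], out[1]); /* 8 x 6 */
  return 0;
}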
+ */ + +#include "nnacl_c/infer/topk_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output0, input); + output1->data_type_ = kNumberTypeInt32; + output1->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input_k_tensor = inputs[1]; + if (input_k_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + + TopkParameter *param = (TopkParameter *)parameter; + param->k_ = ((int32_t *)input_k_tensor->data_)[0]; + + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); + if (out_shape_size < 1) { + return NNACL_ERR; + } + if (param->axis_ < 0) { + param->axis_ += (int)out_shape_size; + } + if (param->axis_ < 0 || (size_t)param->axis_ >= out_shape_size) { + return NNACL_ERR; + } + out_shape[(size_t)param->axis_] = param->k_; + + SetShapeArray(output0, out_shape, out_shape_size); + SetShapeArray(output1, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(TopK, PrimType_TopKFusion, TopKInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a08f06b23f810c0d315814883b23b3047e3ffd3d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TOPK_INFER_H +#define MINDSPORE_NNACL_TOPK_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/topk_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TOPK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7bfa607778a65762665ee9a27cc677d24576de49 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c @@ -0,0 +1,137 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/transpose_infer.h" +#include "nnacl_c/infer/infer_register.h" + +bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const int size) { + for (int i = 0; i < size; ++i) { + if (perm[i] != perm_transformat[i]) { + return false; + } + } + return true; +} + +int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, const int *perm, size_t perm_size, + int *out_shape) { + // set output shape + size_t in_shape_size = input->shape_size_; + output->shape_size_ = in_shape_size; + if (perm_size == 0) { + for (size_t i = 0; i < in_shape_size; ++i) { + out_shape[in_shape_size - i - 1] = input->shape_[i]; + } + } else if (perm_size != in_shape_size) { + for (size_t i = 0; i < in_shape_size; ++i) { + out_shape[i] = input->shape_[i]; + } + } else { + output->shape_size_ = perm_size; + for (size_t i = 0; i < perm_size; ++i) { + if (perm[i] >= input->shape_size_) { + return NNACL_ERR; + } else { + out_shape[i] = input->shape_[perm[i]]; + } + } + } + return NNACL_OK; +} + +int GetAndCheckPerm(const TensorC *perm_tensor, const int perms_num, int *perm, size_t *perm_size) { + if (perms_num >= MAX_TRANSPOSE_DIM_SIZE) { + return NNACL_TRANSPOSE_PERM_DIMS_INVALID; + } + + int ret = GetInt32DataFromTensor(perm_tensor, perm, perm_size); + if (ret != NNACL_OK) { + return ret; + } + for (size_t i = 0; i < *perm_size; i++) { + NNACL_CHECK_TRUE_RET(perm[i] < perms_num, NNACL_ERR); + } + return NNACL_OK; +} + +void Handle4DPerm(const TensorC *input, TensorC *output, int *perm, size_t *perm_size) { + const int nchw2nhwc[4] = {Index0, Index2, Index3, Index1}; + const int nhwc2nchw[4] = {Index0, Index3, Index1, Index2}; + const int trans3d[3] = {Index0, Index2, Index1}; + if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, PERM_NUM_FOUR)) { + output->format_ = Format_NHWC; + } else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) && + CheckPermTransFormat(perm, nhwc2nchw, PERM_NUM_FOUR)) { + output->format_ = Format_NCHW; + } + // though the perm is 4d in default, the input can be a 3d tensor. The op implementation must be adapted to this. 
+ if (input->shape_size_ == DIMENSION_3D) { + ShapeSet(perm, perm_size, trans3d, DIMENSION_3D); + } +} + +int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + const TensorC *perm_tensor = inputs[1]; + if (perm_tensor == NULL) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(perm_tensor->shape_size_ == 1, NNACL_INFER_INVALID); + const int perms_num = perm_tensor->shape_[0]; + if (perms_num != 0 && perm_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + TransposeParameter *transpose_param = (TransposeParameter *)parameter; + transpose_param->perm_size_ = perms_num; + int perm[MAX_TRANSPOSE_DIM_SIZE] = {0}; + size_t perm_size = 0; + int ret = GetAndCheckPerm(perm_tensor, perms_num, perm, &perm_size); + if (ret != NNACL_OK) { + return ret; + } + + if (perms_num == PERM_NUM_FOUR) { + Handle4DPerm(input, output, perm, &perm_size); + } + int kPermIndex0 = 0; + int kPermIndex2 = 2; + if (perms_num == PERM_NUM_THREE && perm[0] == kPermIndex0 && perm[1] == kPermIndex2) { + output->format_ = input->format_ == Format_NCHW ? Format_NHWC : Format_NCHW; + } + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // set output shape + int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0}; + SetOutputShape(perms_num, input, output, perm, perm_size, out_shape); + SetShapeArray(output, out_shape, output->shape_size_); + return NNACL_OK; +} + +REG_INFER(Transpose, PrimType_Transpose, TransposeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..2557fcbde82fdd4a5741fa6c6ac0da1833cca429 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
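SetOutputShape above distinguishes three cases: an empty perm reverses the shape, a perm whose length differs from the input rank leaves the shape unchanged, and a matching perm gathers out[i] = in[perm[i]]; Handle4DPerm additionally flips the recorded format when the perm is exactly the NCHW<->NHWC permutation. The gather itself as a stand-alone sketch (hypothetical names, not part of the patch):

#include <stdio.h>

/* Sketch of the perm-length == rank case: out[i] = in[perm[i]],
 * rejecting perm entries outside the input rank. */
static int transpose_shape(const int *in, int rank, const int *perm, int *out) {
  for (int i = 0; i < rank; ++i) {
    if (perm[i] < 0 || perm[i] >= rank) return 0;
    out[i] = in[perm[i]];
  }
  return 1;
}

int main(void) {
  int in[] = {1, 24, 24, 3}, perm[] = {0, 3, 1, 2}, out[4]; /* NHWC -> NCHW */
  if (transpose_shape(in, 4, perm, out))
    printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); /* 1 3 24 24 */
  return 0;
}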
+ */ +#ifndef MINDSPORE_NNACL_TRANSPOSE_INFER_H +#define MINDSPORE_NNACL_TRANSPOSE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TRANSPOSE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..faa45c8cc5d5f2b235e4684cee066a230b7e4db9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/triu_tril_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int TriuTrilInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + const size_t triul_input_min_size = 1; + const size_t triul_output_size = 1; + if (inputs_size < triul_input_min_size || outputs_size != triul_output_size) { + return NNACL_ERR; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Triu, PrimType_Triu, TriuTrilInferShape) +REG_INFER(Tril, PrimType_Tril, TriuTrilInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5dd85d45f63fc97cadebda65fc3b6d570dc158b0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_TRIU_TRIL_INFER_H +#define MINDSPORE_NNACL_TRIU_TRIL_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/arg_min_max_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TriuTrilInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TRIU_TRIL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4614c7acd76e3e8a05d4ca4c5e6accd328bdbefd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/uniform_real_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int UniformRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + outputs[0]->data_type_ = kNumberTypeFloat32; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int32_t *input_data = (int32_t *)(inputs[0]->data_); + if (input_data == NULL) { + return NNACL_INFER_INVALID; + } + int input_num = NNACLGetElementNum(inputs[0]); + if (input_num > MAX_SHAPE_SIZE || input_num < 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = (size_t)(input_num); + for (int i = 0; i < input_num; i++) { + output_shape[i] = input_data[i]; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(UniformReal, PrimType_UniformReal, UniformRealInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d3aad6a79aec14d9e54b53fcbf83489804e80039 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNIFORM_REAL_INFER_H +#define MINDSPORE_NNACL_UNIFORM_REAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UniformRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNIFORM_REAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ab24aebe5c9680774905a8f586c733ec9e295cac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unique_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UniqueInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + SetDataTypeFormat(output0, input0); + output1->data_type_ = kNumberTypeInt32; + output1->format_ = input0->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input0); + SetShapeTensor(output1, input0); + return NNACL_OK; +} + +REG_INFER(Unique, PrimType_Unique, UniqueInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b97e37decf31badd5248d6d55b89562fe6707142 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_UNIQUE_INFER_H +#define MINDSPORE_NNACL_UNIQUE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UniqueInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNIQUE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4fdfe03f6a8fb342dc6d59901d53e3f4e74c7be3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unsorted_segment_sum_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *out = outputs[0]; + const TensorC *x = inputs[0]; + const TensorC *segment_id = inputs[1]; + if (inputs[2]->data_ == NULL || + (inputs[2]->data_type_ != kNumberTypeInt && inputs[2]->data_type_ != kNumberTypeInt32)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int num_segments = *(int *)(inputs[2]->data_); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapePush(output_shape, &output_shape_size, num_segments); + for (int index = (int)(segment_id->shape_size_); index < (int)(x->shape_size_); index++) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, x->shape_[index]); + } + SetShapeArray(out, output_shape, output_shape_size); + SetDataTypeFormat(out, x); + return NNACL_OK; +} + +REG_INFER(UnsortedSegmentSum, PrimType_UnsortedSegmentSum, UnsortedSegmentSumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f6332f5a2e9c6f273e773cbaf2dab32a90142ea8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
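UnsortedSegmentSumInferShape above reads num_segments from the third input's payload and builds the output as [num_segments] followed by the trailing dims of x that the segment-id tensor does not cover. The same rule in isolation (hypothetical names, not part of the patch):

#include <stdio.h>

/* Sketch: out = [num_segments] ++ x_shape[seg_rank .. x_rank). */
static int uss_shape(const int *x, int x_rank, int seg_rank, int num_segments, int *out) {
  int n = 0;
  out[n++] = num_segments;
  for (int i = seg_rank; i < x_rank; ++i) out[n++] = x[i];
  return n;
}

int main(void) {
  int x[] = {6, 4, 5}, out[3];
  int n = uss_shape(x, 3, 1, 8, out); /* segment ids cover dim 0 only */
  for (int i = 0; i < n; ++i) printf("%d ", out[i]); /* 8 4 5 */
  printf("\n");
  return 0;
}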
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H +#define MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct UnsortedSegmentSumParameter { + OpParameter op_parameter_; + int segments_num_; +} UnsortedSegmentSumParameter; + +int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3bbf1d281adfb2b8e53c14a8cffcf241f58f9843 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/unsqueeze_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + UnSqueezeParameter *param = (UnSqueezeParameter *)parameter; + int in_rank = (int)(input->shape_size_); + int dim_rank = param->num_dim_; + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + if (dim_rank == 0) { + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != 1) { + if (out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } else { + int sz = in_rank + dim_rank; + size_t in_itr = 0; + size_t ax_itr = 0; + if (sz < 0) { + return NNACL_ERR; + } + for (int i = 0; i < sz; i++) { + if (out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + if (ax_itr < (size_t)(dim_rank) && param->dims_[ax_itr] == (int)(i)) { + ShapePush(out_shape, &out_shape_size, 1); + ax_itr++; + } else if (ax_itr < (size_t)(dim_rank) && param->dims_[ax_itr] + sz == i) { + ShapePush(out_shape, &out_shape_size, 1); + ax_itr++; + } else { + if (in_itr >= input->shape_size_) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, input->shape_[in_itr]); + in_itr++; + } + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Unsqueeze, PrimType_Unsqueeze, UnsqueezeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fa7b96e4f264b5f1f2743baacb620c3c253f49d4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_UNSQUEEZE_INFER_H +#define MINDSPORE_NNACL_UNSQUEEZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/unsqueeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSQUEEZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b604c1b341b4ce58cb6543346ee1468ffdc55585 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unstack_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + UnstackParameter *param = (UnstackParameter *)parameter; + int axis = param->axis_ < 0 ? param->axis_ + (int)(input->shape_size_) : param->axis_; + if (axis < 0 || axis >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (size_t i = 0; i < input->shape_size_; ++i) { + if (i != (size_t)(axis)) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, input->shape_[i]); + } + } + for (size_t i = 0; i < outputs_size; i++) { + if (outputs[i] == NULL) { + return NNACL_NULL_PTR; + } + SetShapeArray(outputs[i], output_shape, output_shape_size); + } + return NNACL_OK; +} + +REG_INFER(Unstack, PrimType_Unstack, UnstackInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..386447e83e1a2727b0f365c3e48e7cc6d531883d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
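UnstackInferShape above normalizes the axis and then gives every output the input shape with that axis removed; all outputs share one shape, and the number of outputs equals the size of the removed axis. In isolation (hypothetical names, not part of the patch):

#include <stdio.h>

/* Sketch: the per-output shape of Unstack is the input shape minus the
 * unstacked axis; returns the output rank, or -1 for a bad axis. */
static int unstack_shape(const int *in, int rank, int axis, int *out) {
  if (axis < 0) axis += rank;
  if (axis < 0 || axis >= rank) return -1;
  int n = 0;
  for (int i = 0; i < rank; ++i)
    if (i != axis) out[n++] = in[i];
  return n;
}

int main(void) {
  int in[] = {4, 2, 3}, out[2];
  int n = unstack_shape(in, 3, 0, out); /* yields 4 outputs of shape 2 x 3 */
  for (int i = 0; i < n; ++i) printf("%d ", out[i]); /* 2 3 */
  printf("\n");
  return 0;
}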
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSTACK_INFER_H +#define MINDSPORE_NNACL_UNSTACK_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/unstack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSTACK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..06a10cbe8e124e58af55f057c0dfef7f41603750 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/where_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/infer/broadcast_to_infer.h" + +int WhereBroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape, + bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(output); + + // Need to dynamically allocate at runtime. 
+ if (inputs_size == 1) { + output->data_type_ = kNumberTypeInt32; + output->format_ = input0->format_; + return NNACL_INFER_INVALID; + } + + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input1 = inputs[1]; + const TensorC *input2 = inputs[2]; + NNACL_CHECK_NULL_RETURN_ERR(input1); + NNACL_CHECK_NULL_RETURN_ERR(input2); + SetDataTypeFormat(output, input1); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + int in_shape2[MAX_SHAPE_SIZE] = {0}; + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape0_size = input0->shape_size_; + size_t input_shape1_size = input1->shape_size_; + size_t input_shape2_size = input2->shape_size_; + const int *input_shape0 = input0->shape_; + const int *input_shape1 = input1->shape_; + const int *input_shape2 = input2->shape_; + int ndim = (int)input_shape0_size; + bool has_broad_cast_1 = false; + bool has_broad_cast_2 = false; + if (WhereBroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0, + in_shape1, output_shape, &has_broad_cast_1) != NNACL_OK) { + return NNACL_ERR; + } + if (WhereBroadCastInferShape(ndim, input_shape2_size, output_shape, input_shape2, &ndim, in_shape0, in_shape2, + output_shape, &has_broad_cast_2) != NNACL_OK) { + return NNACL_ERR; + } + ShapeSet(output->shape_, &output->shape_size_, output_shape, ndim); + return NNACL_OK; +} + +REG_INFER(Where, PrimType_Where, WhereInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6dadfc79afda6b2fbec6c84a212472599f6adaed --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_WHERE_INFER_H +#define MINDSPORE_NNACL_WHERE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_WHERE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..daa82e28312fed57dcb489f9249af94e68240498 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
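WhereInferShape above broadcasts in two passes, condition against x and the intermediate result against y, so the output is the common broadcast of all three shapes (the single-input form stays dynamic until runtime, hence the early NNACL_INFER_INVALID). One right-aligned broadcast step, which the helper effectively applies twice, as a stand-alone sketch (hypothetical names, not part of the patch):

#include <stdio.h>

/* Sketch of one numpy-style broadcast step: right-aligned dims must match
 * or one side must be 1. Returns 0 on a mismatch, as NNACL_ERR does. */
static int broadcast2(const int *a, int ar, const int *b, int br, int *out) {
  int r = ar > br ? ar : br;
  for (int i = 0; i < r; ++i) {
    int da = i < r - ar ? 1 : a[i - (r - ar)];
    int db = i < r - br ? 1 : b[i - (r - br)];
    if (da != db && da != 1 && db != 1) return 0;
    out[i] = da > db ? da : db;
  }
  return r;
}

int main(void) {
  int cond[] = {4, 1}, x[] = {3}, y[] = {1, 3}, tmp[2], out[2];
  int r = broadcast2(cond, 2, x, 1, tmp); /* {4, 3} */
  if (r > 0) r = broadcast2(tmp, r, y, 2, out);
  if (r > 0) printf("%d x %d\n", out[0], out[1]); /* 4 x 3 */
  return 0;
}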
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INSTANCE_NORM_PARAMETER_H_ +#define NNACL_INSTANCE_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct InstanceNormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + // shape correlative + int batch_; + int channel_; + int inner_size_; +} InstanceNormParameter; + +#endif // NNACL_INSTANCE_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..a1f54b43c1c58b6f793ce2fb3b6251190c843238 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c @@ -0,0 +1,531 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/avx/common_utils.h" +#endif +#include "nnacl_c/int8/fixed_point.h" + +#ifdef ENABLE_ARM +void AddInt8InputRounding(int32x4_t *in1, int32x4_t *in2, int32x4_t *in3, int32x4_t *in4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *in1 = vmulq_s32(*in1, left_vec); + *in2 = vmulq_s32(*in2, left_vec); + *in3 = vmulq_s32(*in3, left_vec); + *in4 = vmulq_s32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. + *in1 = vqrdmulhq_n_s32(*in1, multiplier); + *in2 = vqrdmulhq_n_s32(*in2, multiplier); + *in3 = vqrdmulhq_n_s32(*in3, multiplier); + *in4 = vqrdmulhq_n_s32(*in4, multiplier); + + // Apply right shift + *in1 = vqaddq_s32(*in1, vshrq_n_s32(vandq_s32(*in1, right_vec), 31)); + *in2 = vqaddq_s32(*in2, vshrq_n_s32(vandq_s32(*in2, right_vec), 31)); + *in3 = vqaddq_s32(*in3, vshrq_n_s32(vandq_s32(*in3, right_vec), 31)); + *in4 = vqaddq_s32(*in4, vshrq_n_s32(vandq_s32(*in4, right_vec), 31)); + + *in1 = vrshlq_s32(*in1, right_vec); + *in2 = vrshlq_s32(*in2, right_vec); + *in3 = vrshlq_s32(*in3, right_vec); + *in4 = vrshlq_s32(*in4, right_vec); +} + +void AddInt8OutputRounding(int32x4_t *out1, int32x4_t *out2, int32x4_t *out3, int32x4_t *out4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *out1 = vshlq_s32(*out1, left_vec); + *out2 = vshlq_s32(*out2, left_vec); + *out3 = vshlq_s32(*out3, left_vec); + *out4 = vshlq_s32(*out4, left_vec); + + // Apply the fixed-point part of the multiplier. 
+ *out1 = vqrdmulhq_n_s32(*out1, multiplier); + *out2 = vqrdmulhq_n_s32(*out2, multiplier); + *out3 = vqrdmulhq_n_s32(*out3, multiplier); + *out4 = vqrdmulhq_n_s32(*out4, multiplier); + + // Apply right shift + *out1 = vqaddq_s32(*out1, vshrq_n_s32(vandq_s32(*out1, right_vec), 31)); + *out2 = vqaddq_s32(*out2, vshrq_n_s32(vandq_s32(*out2, right_vec), 31)); + *out3 = vqaddq_s32(*out3, vshrq_n_s32(vandq_s32(*out3, right_vec), 31)); + *out4 = vqaddq_s32(*out4, vshrq_n_s32(vandq_s32(*out4, right_vec), 31)); + + *out1 = vrshlq_s32(*out1, right_vec); + *out2 = vrshlq_s32(*out2, right_vec); + *out3 = vrshlq_s32(*out3, right_vec); + *out4 = vrshlq_s32(*out4, right_vec); +} +#endif + +void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, const AddQuantParameter *params) { + int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); + int index = 0; +#ifdef ENABLE_ARM + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vec = vdupq_n_s8(params->max_); + + const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_); + const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_); + const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift); + const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift); + + const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_args_.right_shift_); + const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_args_.right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + for (; index <= size - 16; index += 16) { + const int8x16_t in0_src = vld1q_s8(input0 + index); + const int8x16_t in1_src = vld1q_s8(input1 + index); + + const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src)); + const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src)); + const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src)); + const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src)); + + const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec); + const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec); + const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec); + const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec); + + int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low)); + int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low)); + int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high)); + int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high)); + int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low)); + int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low)); + int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); + int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); + + AddInt8InputRounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, in0_right_vec, params->in0_args_.multiplier_); + AddInt8InputRounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, in1_right_vec, params->in1_args_.multiplier_); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(in0_1, in1_1); + int32x4_t out2 = vaddq_s32(in0_2, in1_2); + int32x4_t out3 = vaddq_s32(in0_3, in1_3); + int32x4_t out4 = vaddq_s32(in0_4, in1_4); + + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); 
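+    // (vmovn_s32 narrows each 32-bit lane to int16 by truncation; saturation
+    //  happens later in vqmovn_s16, once the output zero-point has been added
+    //  back.)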
+ const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + for (; index < size; index++) { + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); + const int32_t in1 = + MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args) { + int ptr_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); + int ele_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); + int index = 0; + +#ifdef ENABLE_ARM + /* const value init */ + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vec = vdupq_n_s8(params->max_); + + const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_); + const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_); + const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t ptr_left_vec = vdupq_n_s32(ptr_left_shift); + const int32x4_t ele_left_vec = vdupq_n_s32(ele_left_shift); + + const int32x4_t ptr_right_vec = vdupq_n_s32(-ptr_args->right_shift_); + const int32x4_t ele_right_vec = vdupq_n_s32(-ele_args->right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + /* deal with const node */ + const int8x16_t ele_src = vdupq_n_s8(element_in); + const int16x8_t ele_s16_low = vmovl_s8(vget_low_s8(ele_src)); + const int16x8_t ele_s16_high = vmovl_s8(vget_high_s8(ele_src)); + const int16x8_t ele_zp_low = vaddq_s16(ele_s16_low, ele_zp_vec); + const int16x8_t ele_zp_high = vaddq_s16(ele_s16_high, ele_zp_vec); + int32x4_t ele1 = vmovl_s16(vget_low_s16(ele_zp_low)); + int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low)); + int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high)); + int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high)); + + AddInt8InputRounding(&ele1, &ele2, &ele3, &ele4, ele_left_vec, ele_right_vec, ele_args->multiplier_); + + for (; index <= size - 16; index += 16) { + const int8x16_t ptr_src = vld1q_s8(ptr_in + index); + + const int16x8_t ptr_s16_low = vmovl_s8(vget_low_s8(ptr_src)); + const int16x8_t ptr_s16_high = vmovl_s8(vget_high_s8(ptr_src)); + + const int16x8_t ptr_zp_low = vaddq_s16(ptr_s16_low, ptr_zp_vec); + const int16x8_t ptr_zp_high = vaddq_s16(ptr_s16_high, ptr_zp_vec); + + int32x4_t ptr1 = vmovl_s16(vget_low_s16(ptr_zp_low)); + int32x4_t ptr2 = vmovl_s16(vget_high_s16(ptr_zp_low)); + int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high)); + 
int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high)); + + AddInt8InputRounding(&ptr1, &ptr2, &ptr3, &ptr4, ptr_left_vec, ptr_right_vec, ptr_args->multiplier_); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(ptr1, ele1); + int32x4_t out2 = vaddq_s32(ptr2, ele2); + int32x4_t out3 = vaddq_s32(ptr3, ele3); + int32x4_t out4 = vaddq_s32(ptr4, ele4); + + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); + const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + for (; index < size; index++) { + const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift; + const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift; + const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size) { + for (int i = 0; i < size; i++) { + out[i] = in0[i] + in1[i]; + } + return NNACL_OK; +} + +int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size, + ArithmeticParameter *param) { + TileDimensionsInt8(in0, in1, tile_in0, tile_in1, param); + return ElementAddInt8(tile_in0, tile_in1, out, size); +} + +#ifdef ENABLE_AVX +void AddInt8Rounding(__m128i *in1, __m128i *in2, __m128i *in3, __m128i *in4, const __m128i left_vec, + const int32_t right_shift, const __m128i multiplier) { + // Apply left shift + *in1 = _mm_mullo_epi32(*in1, left_vec); + *in2 = _mm_mullo_epi32(*in2, left_vec); + *in3 = _mm_mullo_epi32(*in3, left_vec); + *in4 = _mm_mullo_epi32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. 
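+  // (_mm_qrdmulh_epi32 is not a native AVX intrinsic; it is assumed here to be
+  //  the helper from nnacl_c/intrinsics/avx/common_utils.h included above,
+  //  mirroring NEON vqrdmulh semantics: saturating round(x * m / 2^31).)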
+ *in1 = _mm_qrdmulh_epi32(*in1, multiplier); + *in2 = _mm_qrdmulh_epi32(*in2, multiplier); + *in3 = _mm_qrdmulh_epi32(*in3, multiplier); + *in4 = _mm_qrdmulh_epi32(*in4, multiplier); + + // Apply right shift + int32_t in1_remainder_mask = (1ll << (right_shift)) - 1; + int32_t in1_remainder_threshold = in1_remainder_mask >> 1; + const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); + const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); + + const __m128i in1_remainder = + _mm_add_epi32(_mm_and_si128(*in1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in1)); + *in1 = _mm_sub_epi32(_mm_rshr_epi32(*in1, right_shift), _mm_cmpgt_epi32(in1_remainder, vin1_remainder_threshold)); + + const __m128i in2_remainder = + _mm_add_epi32(_mm_and_si128(*in2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in2)); + *in2 = _mm_sub_epi32(_mm_rshr_epi32(*in2, right_shift), _mm_cmpgt_epi32(in2_remainder, vin1_remainder_threshold)); + + const __m128i in3_remainder = + _mm_add_epi32(_mm_and_si128(*in3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in3)); + *in3 = _mm_sub_epi32(_mm_rshr_epi32(*in3, right_shift), _mm_cmpgt_epi32(in3_remainder, vin1_remainder_threshold)); + + const __m128i in4_remainder = + _mm_add_epi32(_mm_and_si128(*in4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in4)); + *in4 = _mm_sub_epi32(_mm_rshr_epi32(*in4, right_shift), _mm_cmpgt_epi32(in4_remainder, vin1_remainder_threshold)); +} + +void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, + const AddQuantParameter *params) { + const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); + const __m128i min_vec = _mm_set1_epi8(params->min_); + const __m128i max_vec = _mm_set1_epi8(params->max_); + const __m128i in0_zp_vec = _mm_set1_epi16(params->in0_args_.zp_); + const __m128i in1_zp_vec = _mm_set1_epi16(params->in1_args_.zp_); + const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); + const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); + const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); + const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); + const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); + const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); + int index = 0; + for (; index <= size - 16; index += 16) { + const __m128i in0_src = _mm_loadu_si128((__m128i *)(input0 + index)); + const __m128i in1_src = _mm_loadu_si128((__m128i *)(input1 + index)); + + const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); + const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); + const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); + const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); + const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); + const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); + + const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); + const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); + const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); + const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); + + __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); + __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); + tmp_in0 = 
_mm256_cvtepi16_epi32(in0_zp_high); + __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); + __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); + __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); + tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); + __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); + + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); + + /* calculate output */ + __m128i out1 = _mm_add_epi32(in0_1, in1_1); + __m128i out2 = _mm_add_epi32(in0_2, in1_2); + __m128i out3 = _mm_add_epi32(in0_3, in1_3); + __m128i out4 = _mm_add_epi32(in0_4, in1_4); + + // Apply left shift + out1 = _mm_slli_epi32(out1, params->out_left_shift_); + out2 = _mm_slli_epi32(out2, params->out_left_shift_); + out3 = _mm_slli_epi32(out3, params->out_left_shift_); + out4 = _mm_slli_epi32(out4, params->out_left_shift_); + + // Apply the fixed-point part of the multiplier. + out1 = _mm_qrdmulh_epi32(out1, out_multiplier); + out2 = _mm_qrdmulh_epi32(out2, out_multiplier); + out3 = _mm_qrdmulh_epi32(out3, out_multiplier); + out4 = _mm_qrdmulh_epi32(out4, out_multiplier); + + // Apply right shift + int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; + int32_t out_remainder_threshold = out_remainder_mask >> 1; + const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); + const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); + const __m128i out1_remainder = + _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); + out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), + _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); + const __m128i out2_remainder = + _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); + out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), + _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); + const __m128i out3_remainder = + _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); + out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), + _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); + const __m128i out4_remainder = + _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); + out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), + _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); + + __m128i out1_s16 = _mm_packs_epi32(out1, out2); + __m128i out2_s16 = _mm_packs_epi32(out3, out4); + + __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); + __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); + __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); + __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); + + _mm_storeu_si128((__m128i *)(output + index), int8_out); + } + for (; index < size; index++) { + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, 
params->in0_args_.right_shift_);
+    const int32_t in1 =
+      MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_);
+
+    int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
+                                                -params->out_right_shift_);
+    out += params->out_zp_;
+    output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));
+  }
+  return;
+}
+
+void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size,
+                     const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args) {
+  // ptr_in is the tensor operand and element_in the const scalar operand; either
+  // original input may play either role, so all per-input quant parameters are
+  // taken from ptr_args / ele_args rather than params->in0_args_ / in1_args_.
+  const int in0_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_);
+  const int in1_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_);
+  const __m128i min_vec = _mm_set1_epi8(params->min_);
+  const __m128i max_vec = _mm_set1_epi8(params->max_);
+  const __m128i in0_zp_vec = _mm_set1_epi16(ptr_args->zp_);
+  const __m128i in1_zp_vec = _mm_set1_epi16(ele_args->zp_);
+  const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_);
+  const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift);
+  const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift);
+  const __m128i in0_multiplier = _mm_set1_epi32(ptr_args->multiplier_);
+  const __m128i in1_multiplier = _mm_set1_epi32(ele_args->multiplier_);
+  const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_);
+
+  // input1 can be processed once because it is const
+  const __m128i in1_src = _mm_set1_epi8(element_in);
+  const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src);
+  const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0);
+  const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1);
+  const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec);
+  const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec);
+  __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low);
+  __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0);
+  __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1);
+  tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high);
+  __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0);
+  __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1);
+
+  AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, ele_args->right_shift_, in1_multiplier);
+
+  int index = 0;
+  for (; index <= size - 16; index += 16) {
+    const __m128i in0_src = _mm_loadu_si128((__m128i *)(ptr_in + index));
+    const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src);
+    const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0);
+    const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1);
+    const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec);
+    const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec);
+
+    __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low);
+    __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0);
+    __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1);
+    tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high);
+    __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0);
+    __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1);
+
+    AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, ptr_args->right_shift_, in0_multiplier);
+
+    /* calculate output */
+    __m128i out1 = _mm_add_epi32(in0_1, in1_1);
+    __m128i out2 = _mm_add_epi32(in0_2, in1_2);
+    __m128i out3 = _mm_add_epi32(in0_3, in1_3);
+    __m128i out4 = _mm_add_epi32(in0_4, in1_4);
+
+    //
Apply left shift + out1 = _mm_slli_epi32(out1, params->out_left_shift_); + out2 = _mm_slli_epi32(out2, params->out_left_shift_); + out3 = _mm_slli_epi32(out3, params->out_left_shift_); + out4 = _mm_slli_epi32(out4, params->out_left_shift_); + + // Apply the fixed-point part of the multiplier. + out1 = _mm_qrdmulh_epi32(out1, out_multiplier); + out2 = _mm_qrdmulh_epi32(out2, out_multiplier); + out3 = _mm_qrdmulh_epi32(out3, out_multiplier); + out4 = _mm_qrdmulh_epi32(out4, out_multiplier); + + // Apply right shift + int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; + int32_t out_remainder_threshold = out_remainder_mask >> 1; + const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); + const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); + const __m128i out1_remainder = + _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); + out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), + _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); + const __m128i out2_remainder = + _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); + out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), + _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); + const __m128i out3_remainder = + _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); + out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), + _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); + const __m128i out4_remainder = + _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); + out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), + _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); + + __m128i out1_s16 = _mm_packs_epi32(out1, out2); + __m128i out2_s16 = _mm_packs_epi32(out3, out4); + + __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); + __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); + __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); + __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); + + _mm_storeu_si128((__m128i *)(output + index), int8_out); + } + for (; index < size; index++) { + const int32_t in0_left = (ptr_in[index] + ptr_args->zp_) * in0_left_shift; + const int32_t in1_left = (element_in + ele_args->zp_) * in1_left_shift; + const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, ele_args->multiplier_, ele_args->right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..0ecd815498c8a65661bc5e77aa51b16d4573c56f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ADD_INT8_H_ +#define NNACL_ADD_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/int8/arithmetic_int8.h" + +typedef struct AddQuantQrgs { + int32_t zp_; + int32_t left_shift_; + int32_t right_shift_; + int32_t multiplier_; +} AddQuantQrgs; + +typedef struct AddQuantParameter { + int left_shift_; + int32_t min_; + int32_t max_; + + AddQuantQrgs in0_args_; + AddQuantQrgs in1_args_; + + int32_t out_zp_; + int32_t out_left_shift_; + int32_t out_right_shift_; + int32_t out_multiplier_; +} AddQuantParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, const AddQuantParameter *params); + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args); + +int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size); + +int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size, + ArithmeticParameter *param); + +#ifdef ENABLE_AVX +void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, + const AddQuantParameter *params); + +void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args); +#endif +#ifdef __cplusplus +} +#endif + +#endif // NNACL_ADD_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..0d3318806cc57fa351656985237e797adef9f8a8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c @@ -0,0 +1,237 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/int8/arg_min_max_int8.h" +#include + +void CalcParameter(const int32_t *shape, int dims_number, int axis, int32_t *pre_axis_count, int32_t *axis_count, + int32_t *after_axis_count) { + *pre_axis_count = 1; + for (int i = 0; i < axis; ++i) { + *pre_axis_count = (*pre_axis_count) * shape[i]; + } + + *axis_count = shape[axis]; + + *after_axis_count = 1; + for (int i = axis + 1; i < dims_number; ++i) { + *after_axis_count = (*after_axis_count) * shape[i]; + } +} + +void SetOutputValue(float value, int32_t index, int8_t *output1, int8_t *output2, int offset, + float output_inverse_scale, float output_zp, bool out_value) { + if (output2 != NULL) { + int32_t *output1_index = (int32_t *)output1; + output1_index[offset] = index; + output2[offset] = value * output_inverse_scale + output_zp; + } else { + if (out_value) { + output1[offset] = value * output_inverse_scale + output_zp; + } else { + int32_t *output1_index = (int32_t *)output1; + output1_index[offset] = index; + } + } +} + +void DoArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < pre_axis_count; ++i) { + int output_offset = i * after_axis_count; + int input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float value = -FLT_MAX; + if (!param->get_max_) { + value = FLT_MAX; + } + int32_t index = 0; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j] * in_quant_arg->scale_ + bias; + if (param->get_max_) { + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } else { + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + } + SetOutputValue(value, index, output1, output2, output_offset + j, output_inverse_scale, output_zp, out_value); + } + } +} + +void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + const ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + CalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + DoArgMinMaxQuant(input, output1, output2, param, pre_axis_count, axis_count, after_axis_count, in_quant_arg, + out_quant_arg); + return; +} + +int ArgCompareAscInt8(const void *a, const void *b) { + return ((ArgElement *)a)->data_.f_data_ - ((ArgElement *)b)->data_.f_data_; +} + +int ArgCompareDescInt8(const void *a, const void *b) { + return ((ArgElement *)b)->data_.f_data_ - ((ArgElement *)a)->data_.f_data_; +} + +int8_t GetInt8Output(float real_out, float output_inverse_scale, int32_t output_zp) { + return real_out * output_inverse_scale + output_zp; +} + +void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + for (int32_t i 
= 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + int offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = (uint32_t)j; + param->arg_elements_[j].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscInt8); + } + + for (int j = 0; j < param->topk_; ++j) { + int out_offset = j * param->out_strides_[0] + i; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } +} + +void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + int in_dim0_offset = i * param->in_strides_[0]; + int out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + int offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = (size_t)k; + param->arg_elements_[k].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscInt8); + } + + for (int k = 0; k < param->topk_; ++k) { + int out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } + } +} + +void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + int in_dim0_offset = i * param->in_strides_[0]; + int out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + int offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = (uint32_t)l; + param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscInt8); + } + for (int l = 0; l < param->topk_; ++l) { + int out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + 
SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } + } + } +} + +void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + int in_dim0_offset = i * param->in_strides_[0]; + int out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + int in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + int out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + int offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = (uint32_t)l; + param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscInt8); + } + for (int l = 0; l < param->topk_; ++l) { + int out_offset = out_dim2_offset + l; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..d4ccfc5a228c7cbd039af6f83b20c52258da2ea4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_INT8_ARG_MIN_MAX_INT8_H_ +#define NNACL_INT8_ARG_MIN_MAX_INT8_H_ + +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + const ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARG_MIN_MAX_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..f08592a08d0b78c5a32dd9cf15e013f44890f6b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/arithmetic_int8.h" +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl_c/errorcode.h" + +void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple) { + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(int8_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionInt8(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionInt8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionInt8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +#define ACCURACY_DATA 0.00000001 + +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float minus_inputs = in0_real - in1_real; + bool out_real = true; + if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { + out_real = false; + } + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float minus_inputs = in0_real - in1_real; + bool out_real = false; + if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { + out_real = true; + } + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real < in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * 
quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real <= in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real > in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real >= in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +#undef ACCURACY_DATA diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..8d1d2be04ee681b5e4efbca0b7fc1971fefdfe32 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_INT8_ARITHMETIC_INT8_H_ +#define NNACL_INT8_ARITHMETIC_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple); +void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param); + +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); + +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); + +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..5e3fc901394fc007a85fcd0edd76b321b5cb224b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c @@ -0,0 +1,305 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/arithmetic_self_int8.h" +#include +#include +#ifdef ENABLE_NEON +#include +#include "nnacl_c/int8/common_func_int8.h" +#endif +#include "nnacl_c/int8/fixed_point.h" + +int Int8ElementFloor(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(floorf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementRound(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(round(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementCeil(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(ceil(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementAbs(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(fabsf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementSin(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(sinf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + 
output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementCos(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(cosf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementLog(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(logf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementSqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + int32_t output_tmp = round(sqrtf(input_f32) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementRsqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 <= 0) { + return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + } + int32_t output_tmp = round(1.f / (sqrtf(input_f32) * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input, int32x4_t left_shift_out_vec, int32x4_t output_multiplier_vec, + ArithSelfQuantArg para) { + int32x4_t input_scale = vmulq_s32(scaled_input, scaled_input); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + para.shift_right_); + raw_sum = vaddq_s32(raw_sum, 
vdupq_n_s32(para.out_args_.zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void SquareInt8NEON(const int8_t *input_data, int8_t *output_data, int64_t element_size, ArithSelfQuantArg para, + int32_t *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)para.shift_left_); + + for (; (*index) <= element_size - 8; (*index) += 8) { + int16x8_t input_val = LoadAndAddOffset(input_data, *index, para.in_args_.zp_); + int32x4_t input_low = vmovl_s16(vget_low_s16(input_val)); + int32x4_t input_high = vmovl_s16(vget_high_s16(input_val)); + + int16x4_t sum_low = ClacSumHalfWord(input_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = ClacSumHalfWord(input_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + output_data += 8; + } +} +#endif + +int Int8ElementSquare(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + int32_t in_zp = para.in_args_.zp_; + int32_t out_zp = para.out_args_.zp_; + + int index = 0; +#ifdef ENABLE_NEON + SquareInt8NEON(input, output, element_size, para, &index); +#endif + for (; index < element_size; index++) { + const int32_t input_val = input[index] + in_zp; + int32_t output_tmp = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input_val * input_val * (1 << para.shift_left_), para.output_multiplier_), + para.shift_right_); + output_tmp += out_zp; + if (output_tmp > para.output_activation_max_) { + output[index] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[index] = para.output_activation_min_; + } else { + output[index] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementLogicalNot(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(((float)(!(bool)(input[i] * in_scale + bias))) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (output_tmp); + } + } + return NNACL_OK; +} + +int Int8ElementReciprocal(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (fabs(input_f32) <= FLT_EPSILON) { + return NNACL_ERR; + } + int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..21e7d75b4a507d3f07fd7a030f59dd9353c26025 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_ARITHMETIC_SELF_INT8_H_ +#define NNACL_INT8_ARITHMETIC_SELF_INT8_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Int8ElementRound(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementFloor(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementCeil(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementAbs(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementSin(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementCos(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementLog(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementSqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementRsqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementSquare(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementLogicalNot(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementReciprocal(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_ARITHMETIC_SELF_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..e963505ad35c5d30c1d4fab98695665a4448bfcf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/batch_to_space_int8.h" + +void BatchToSpaceNoCropForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int64_t stride_h = block_w * out_n; + int64_t output_offset = 0; + int64_t in_stride_h = in_w * in_c; + int64_t in_stride_n = in_stride_h * in_h; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + int64_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + int64_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int64_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} + +void BatchToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const int32_t *crops, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + int64_t stride_h = block_w * out_n; + int64_t output_offset = 0; + int64_t in_stride_h = in_w * in_c; + int64_t in_stride_n = in_stride_h * in_h; + + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + int64_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + int64_t h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + int64_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int64_t w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + int64_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? 
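/* Both BatchToSpace variants fold the two quantization parameters into a single
 * affine map before the loops: scale = s_in / s_out and bias = -zp_in * scale,
 * so each element costs one q_out = round(q_in * scale + bias) + zp_out followed
 * by this hard clamp to the int8 range. Worked example (values chosen purely for
 * illustration): with s_in = 0.5, s_out = 0.25, zp_in = 0, zp_out = 10, an input
 * of q_in = 7 maps to round(7 * 2.0 + 0.0) + 10 = 24, which lies inside
 * [-128, 127] and is stored unchanged. */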
-128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..7fe6bdd2be460e13d0d790fda616d8d382458fdb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_BATCH_TO_SPACE_INT8_H_ +#define NNACL_INT8_BATCH_TO_SPACE_INT8_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BatchToSpaceNoCropForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg); +void BatchToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const int32_t *crops, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_BATCH_TO_SPACE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..ea89a205ced721be807f62bd22bc2fee842479f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/batchnorm_int8.h" +#include <math.h> +#include "nnacl_c/batchnorm_parameter.h" + +void BatchNormInt8(int8_t *output_ptr, const int8_t *input_ptr, const float *alpha_ptr, const float *beta_ptr, + int task_id, int unit, int units, int channel) { + int unit_st = task_id * unit; + int unit_end = MSMIN((task_id + 1) * unit, units); + for (int u = unit_st; u < unit_end; u++) { + for (int c = 0; c < channel; c++) { + int32_t output_tmp = round(input_ptr[u * channel + c] * alpha_ptr[c] + beta_ptr[c]); + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? 
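/* BatchNormInt8 assumes the caller has already folded the whole normalization
 * into per-channel float coefficients, leaving only an affine map plus this
 * int8 clamp in the inner loop. A sketch of that folding under the usual
 * fused-BN convention (illustrative names, not an nnacl API; s_in/s_out are the
 * tensor scales and zp_in/zp_out the zero points):
 *
 *   float a  = gamma[c] / sqrtf(var[c] + eps);
 *   alpha[c] = a * s_in / s_out;
 *   beta[c]  = (bn_beta[c] - a * mean[c]) / s_out + zp_out - alpha[c] * zp_in;
 *
 * so that round(q_in * alpha[c] + beta[c]) reproduces
 * quantize(BN(dequantize(q_in))). */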
-128 : output_tmp; + output_ptr[u * channel + c] = (int8_t)output_tmp; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..fc3b1f2559fb8584701ee6619876ba9ed2b86bec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_BATCHNORM_H_ +#define NNACL_INT8_BATCHNORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/batchnorm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormInt8(int8_t *output_ptr, const int8_t *input_ptr, const float *alpha_ptr, const float *beta_ptr, + int task_id, int unit, int units, int channel); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_BATCHNORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..b9424784cc9134cfa1324810e2d35841f6b5db46 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
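
PostConvFuncCommInt8 below converts convolution accumulators from a channel-tiled layout back to plain row-major while applying bias, requantization, and the activation clamp. The tile width is the `size` argument (C4NUM = 4 for the C4 layout), so channel c lives in tile c / size at slot c % size. Worked index example (illustrative values): with size = 4 and in_plane_stride as used by PostFuncInt8C4 (UP_ROUND(plane, 4) * 4), element (r = 1, c = 5) is read from src_index = (5 / 4) * in_plane_stride + 1 * 4 + (5 % 4) and written to dst_index = 1 * out_oc_stride + 5.
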
+ */ + +#include "nnacl_c/int8/common_func_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, + size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi, + int32_t left_shift, int32_t right_shift, int32_t zp, int size) { + if (size == 0) { + return; + } + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c4div = c / size, c4mod = c % size; + int src_index = c4div * in_plane_stride + r * size + c4mod; + int dst_index = r * out_oc_stride + c; + int32_t value = in[src_index]; + if (bias != NULL) { + value = in[src_index] + bias[c]; + } + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + out[dst_index] = (int8_t)value; + } + } + return; +} + +void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, + int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, + int32_t maxi) { +/* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ +#ifndef ENABLE_ARM64 + PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi, + left_shift, right_shift, zp, C4NUM); +#else + size_t oc4div = oc / C4NUM * C4NUM; + size_t oc4res = oc % C4NUM; + PostFuncInt8C4Neon64(in, bias, out, oc4div, oc4res, plane, stride * sizeof(int8_t), multiplier, left_shift, + right_shift, zp, mini, maxi); +#endif + return; +} + +#ifdef ENABLE_ARM +int16x8_t LoadAndAddOffset(const int8_t *data, int index, int offset) { + int8x8_t input_s8 = vld1_s8(data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + return vaddq_s16(input_s16, vdupq_n_s16(offset)); +} + +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec) { + int32x4_t shifted_input = vmulq_s32(input, left_shift_result_vec); + shifted_input = vqrdmulhq_s32(shifted_input, input_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(shifted_input, right_shift_vec), 31); + return vrshlq_s32(vqaddq_s32(shifted_input, fixup), right_shift_vec); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..fba3af4c2cfadf7cfcc97271e874f3d6bd8320b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_COMMON_FUNC_H_ +#define NNACL_INT8_COMMON_FUNC_H_ + +#include +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, + int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, + int32_t maxi); +#ifdef ENABLE_ARM +void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, + int output_channel, int input_step, int8_t input_zp); +void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, + const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max); +void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, + int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, + size_t oc4, size_t offset); +void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, + size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, const int8_t *in_zp, + const int32_t *out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *acc_min, const int32_t *acc_max); +void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, + int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, + int32_t acc_max); +int16x8_t LoadAndAddOffset(const int8_t *data, int index, int offset); +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec); +#endif + +#ifdef ENABLE_ARM32 +void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + int32_t acc_min, int32_t acc_max, size_t per_channel); +#endif + +#ifdef ENABLE_ARM64 +void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, + size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, + int32_t zp, int32_t mini, int32_t maxi); +void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, + int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, size_t per_channel); +void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, + int input_col_size, int input_row_size, int 
channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, size_t per_channel); +void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, + size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t acc_min, size_t acc_max, + size_t per_channel); +void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, + size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + size_t acc_min, size_t acc_max, size_t per_channel); +void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, + size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + size_t acc_min, size_t acc_max, size_t per_channel); +#endif +#ifdef __cplusplus +} +#endif + +#endif /* NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..6b1d8bf00c668b07d1b1fff69a14224eadcfa396 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
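
Int8Concat, defined next, copies one input slice after another for each outer index, taking a straight memcpy fast path whenever an input already shares the output's scale and zero point, and otherwise requantizing element by element with the same folded affine map used elsewhere in this patch (scale = s_in / s_out, bias = -zp_in * scale). A sketch of how a caller would typically derive the per-thread split the function expects (illustrative only, not lifted from the lite runtime):

/* outer_size: product of the dimensions before `axis`. */
int64_t count_unit = thread_num > 1 ? UP_DIV(outer_size, thread_num) : outer_size;
int64_t real_dst_count = MSMIN(outer_size - task_id * count_unit, count_unit);
/* each task then calls Int8Concat(..., real_dst_count, task_id, ..., count_unit, ...) */
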
+ */ + +#include "nnacl_c/int8/concat_int8.h" +#include <string.h> +#include <math.h> +#include <float.h> +#include "nnacl_c/concat_parameter.h" + +void Int8Concat(int8_t **inputs, int8_t *output, const ConcatParameter *para, int axis, int64_t real_dst_count, + int task_id, int input_num, int64_t count_unit, int64_t after_axis_size, int **input_shapes, + const int32_t *output_shape) { + float output_scale = para->quant_arg_.out_args_.scale_; + const float output_inverse_scale = 1.f / output_scale; + int out_copy_size = output_shape[axis] * after_axis_size; + QuantArg *input_quant = para->quant_arg_.in_args_; + int output_zp = para->quant_arg_.out_args_.zp_; + int8_t max_int8 = para->quant_arg_.output_activation_max_; + int8_t min_int8 = para->quant_arg_.output_activation_min_; + int64_t start = task_id * count_unit; + int64_t end = start + real_dst_count; + output += start * out_copy_size; + + for (int k = start; k < end; k++) { + for (int i = 0; i < input_num; i++) { + const int32_t *input_shape = input_shapes[i]; + int64_t in_copy_size = input_shape[axis] * after_axis_size; + const int8_t *input_ptr = inputs[i] + k * in_copy_size; + if (fabs(input_quant[i].scale_ - output_scale) <= FLT_EPSILON && input_quant[i].zp_ == output_zp) { + memcpy(output, input_ptr, in_copy_size); + } else { + float scale = input_quant[i].scale_ * output_inverse_scale; + float bias = -input_quant[i].zp_ * scale; + for (int j = 0; j < in_copy_size; j++) { + int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp; + output_tmp = output_tmp > min_int8 ? output_tmp : min_int8; + output_tmp = output_tmp < max_int8 ? output_tmp : max_int8; + output[j] = (int8_t)output_tmp; + } + } + output += in_copy_size; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e80e88316ce9b33026cf5d7e9c995a4e28addead --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_CONCAT_INT8_H_ +#define NNACL_INT8_CONCAT_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/concat_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Concat(int8_t **inputs, int8_t *output, const ConcatParameter *para, int axis, int64_t real_dst_count, + int task_id, int input_num, int64_t count_unit, int64_t after_axis_size, int **input_shapes, + const int32_t *output_shape); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONCAT_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..9a51d27f058743214fd38f987855059149e80729 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/conv1x1_int8.h" + +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, + const int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, + left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, + filter_zp); + return; +} + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, const int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, + conv_param->output_channel_, is_per_oc, filter_zp); + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..fd8868e2c68c4c64fcbc288710b536176a609f00 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
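
A 1x1 convolution over NHWC data is exactly a matrix multiply: rows are spatial positions, columns are output channels, and the reduction runs over input channels. The two wrappers above differ only in the packed reduction depth their backend expects (deep16 for the generic MatmulInt8Opt path, deep4 for the optimized matmul_func path); both forward the per-tensor or per-channel requantization arrays untouched. Shape sketch (per batch, illustrative):

/* packed_input : [row x deep16], row = output_h * output_w
 * packed_weight: [deep16 x col], col = output_channel
 * dst          : [row x col] int8 after requantization and activation clamp */
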
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_CONV1X1_INT8_H_ +#define NNACL_INT8_CONV1X1_INT8_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/int8/matmul_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, const int32_t *filter_zp); +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, + const int32_t *filter_zp); + +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_CONV1X1_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..6ecb53e1faed83e81b026095ea44d3b6f0b84911 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c @@ -0,0 +1,902 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
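
The 3x3 int8 convolution implemented below is Winograd F(2x2, 3x3): each 4x4 input tile d is transformed as BT·d·B, multiplied elementwise against the transformed filter, and reduced back to a 2x2 output tile. Reading the transform code, the matrices it hard-wires are (the filter matrix is pre-doubled so the transform stays in int16; see the magnitude note inside the filter transform):

  BT = |  1  0 -1  0 |    2G = | 2  0  0 |    AT = | 1  1  1  0 |
       |  0  1  1  0 |         | 1  1  1 |         | 0  1 -1 -1 |
       |  0 -1  1  0 |         | 1 -1  1 |
       |  0  1  0 -1 |         | 0  0  2 |
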
+ */ + +#include "nnacl_c/int8/conv3x3_int8.h" + +void Conv3x3Int8InputUnit(const int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { +#ifdef ENABLE_ARM + int16x8_t zp = vdupq_n_s16(input_zp); + + int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); + int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); + int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); + int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); + + int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); + int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); + int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); + int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); + + int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); + int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); + int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); + int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); + + int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); + int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); + int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); + int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); + + int16x8_t t00 = vsubq_s16(d00, d20); + int16x8_t t01 = vsubq_s16(d01, d21); + int16x8_t t02 = vsubq_s16(d02, d22); + int16x8_t t03 = vsubq_s16(d03, d23); + + int16x8_t t10 = vaddq_s16(d10, d20); + int16x8_t t11 = vaddq_s16(d11, d21); + int16x8_t t12 = vaddq_s16(d12, d22); + int16x8_t t13 = vaddq_s16(d13, d23); + + int16x8_t t20 = vsubq_s16(d20, d10); + int16x8_t t21 = vsubq_s16(d21, d11); + int16x8_t t22 = vsubq_s16(d22, d12); + int16x8_t t23 = vsubq_s16(d23, d13); + + int16x8_t t30 = vsubq_s16(d10, d30); + int16x8_t t31 = vsubq_s16(d11, d31); + int16x8_t t32 = vsubq_s16(d12, d32); + int16x8_t t33 = vsubq_s16(d13, d33); + + int16x8_t m00 = vsubq_s16(t00, t02); + int16x8_t m01 = vaddq_s16(t01, t02); + int16x8_t m02 = vsubq_s16(t02, t01); + int16x8_t m03 = vsubq_s16(t01, t03); + + int16x8_t m10 = vsubq_s16(t10, t12); + int16x8_t m11 = vaddq_s16(t11, t12); + int16x8_t m12 = vsubq_s16(t12, t11); + int16x8_t m13 = vsubq_s16(t11, t13); + + int16x8_t m20 = vsubq_s16(t20, t22); + int16x8_t m21 = vaddq_s16(t21, t22); + int16x8_t m22 = vsubq_s16(t22, t21); + int16x8_t m23 = vsubq_s16(t21, t23); + + int16x8_t m30 = vsubq_s16(t30, t32); + int16x8_t m31 = vaddq_s16(t31, t32); + int16x8_t m32 = vsubq_s16(t32, t31); + int16x8_t m33 = vsubq_s16(t31, t33); + + vst1q_s16(trans_input_data, m00); + vst1q_s16(trans_input_data + step, m01); + vst1q_s16(trans_input_data + 2 * step, m02); + vst1q_s16(trans_input_data + 3 * step, m03); + + vst1q_s16(trans_input_data + 4 * step, m10); + vst1q_s16(trans_input_data + 5 * step, m11); + vst1q_s16(trans_input_data + 6 * step, m12); + vst1q_s16(trans_input_data + 7 * step, m13); + + vst1q_s16(trans_input_data + 8 * step, m20); + vst1q_s16(trans_input_data + 9 * step, m21); + vst1q_s16(trans_input_data + 10 * step, m22); + vst1q_s16(trans_input_data + 11 * step, m23); + + vst1q_s16(trans_input_data + 12 * step, m30); + vst1q_s16(trans_input_data + 13 * step, m31); + vst1q_s16(trans_input_data + 14 * step, m32); + vst1q_s16(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C8NUM; i++) { + const int16_t *local_ptr = tmp_data + i; + int16_t d00 = local_ptr[0] - input_zp; + int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; + int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; + int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; + + int16_t d10 = (local_ptr + 4 * 
C8NUM)[0] - input_zp; + int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; + int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; + int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; + + int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; + int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; + int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; + int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; + + int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; + int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; + int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; + int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; + + int16_t t00 = d00 - d20; + int16_t t01 = d01 - d21; + int16_t t02 = d02 - d22; + int16_t t03 = d03 - d23; + + int16_t t10 = d10 + d20; + int16_t t11 = d11 + d21; + int16_t t12 = d12 + d22; + int16_t t13 = d13 + d23; + + int16_t t20 = d20 - d10; + int16_t t21 = d21 - d11; + int16_t t22 = d22 - d12; + int16_t t23 = d23 - d13; + + int16_t t30 = d10 - d30; + int16_t t31 = d11 - d31; + int16_t t32 = d12 - d32; + int16_t t33 = d13 - d33; + + int16_t m00 = t00 - t02; + int16_t m01 = t01 + t02; + int16_t m02 = t02 - t01; + int16_t m03 = t01 - t03; + + int16_t m10 = t10 - t12; + int16_t m11 = t11 + t12; + int16_t m12 = t12 - t11; + int16_t m13 = t11 - t13; + + int16_t m20 = t20 - t22; + int16_t m21 = t21 + t22; + int16_t m22 = t22 - t21; + int16_t m23 = t21 - t23; + + int16_t m30 = t30 - t32; + int16_t m31 = t31 + t32; + int16_t m32 = t32 - t31; + int16_t m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane) { + const int input_unit = 4; + int dst_step = iC8 * C8NUM * C4NUM; + for (int o = 0; o < output_channel; o++) { + int oc4_block_num = o / C4NUM; + int oc4_block_rem = o % C4NUM; + int src_oc_offset = o * iC8 * C8NUM * kernel_plane; + int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; + for (int i = 0; i < iC8; i++) { + const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; + int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; +#ifdef ENABLE_ARM + int16x8_t g00 = vld1q_s16(src_ic8_ptr); + int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); + int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); + int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); + int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); + int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); + int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); + int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); + int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); + + int16x8_t dst00 = vmulq_n_s16(g00, 2); + int16x8_t dst01 = vmulq_n_s16(g01, 2); + int16x8_t dst02 = vmulq_n_s16(g02, 2); + + int16x8_t dst10 = 
vaddq_s16(vaddq_s16(g00, g10), g20); + int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); + int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); + + int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); + int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); + int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); + + int16x8_t dst30 = vmulq_n_s16(g20, 2); + int16x8_t dst31 = vmulq_n_s16(g21, 2); + int16x8_t dst32 = vmulq_n_s16(g22, 2); + + int16x8_t m00 = vmulq_n_s16(dst00, 2); + int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); + int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); + int16x8_t m03 = vmulq_n_s16(dst02, 2); + + int16x8_t m10 = vmulq_n_s16(dst10, 2); + int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); + int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); + int16x8_t m13 = vmulq_n_s16(dst12, 2); + + int16x8_t m20 = vmulq_n_s16(dst20, 2); + int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); + int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); + int16x8_t m23 = vmulq_n_s16(dst22, 2); + + int16x8_t m30 = vmulq_n_s16(dst30, 2); + int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); + int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); + int16x8_t m33 = vmulq_n_s16(dst32, 2); + + dst_ic8_ptr[0] = m00[0]; + dst_ic8_ptr[4] = m00[1]; + dst_ic8_ptr[8] = m00[2]; + dst_ic8_ptr[12] = m00[3]; + dst_ic8_ptr[16] = m00[4]; + dst_ic8_ptr[20] = m00[5]; + dst_ic8_ptr[24] = m00[6]; + dst_ic8_ptr[28] = m00[7]; + + dst_ic8_ptr[0 + dst_step] = m01[0]; + dst_ic8_ptr[4 + dst_step] = m01[1]; + dst_ic8_ptr[8 + dst_step] = m01[2]; + dst_ic8_ptr[12 + dst_step] = m01[3]; + dst_ic8_ptr[16 + dst_step] = m01[4]; + dst_ic8_ptr[20 + dst_step] = m01[5]; + dst_ic8_ptr[24 + dst_step] = m01[6]; + dst_ic8_ptr[28 + dst_step] = m01[7]; + + dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; + dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; + dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; + dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; + dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; + dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; + dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; + + dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; + dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; + dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; + dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; + dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; + dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; + dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; + + dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; + dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; + dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; + dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; + dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; + dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; + dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; + + dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; + dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; + dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; + dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; + dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; + dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; + dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; + + dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; + dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; + dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; + dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; + dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; + dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; + dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; + + dst_ic8_ptr[0 + 7 * 
dst_step] = m13[0]; + dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; + dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; + dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; + dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; + dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; + dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; + dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; + + dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; + dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; + dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; + dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; + dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; + dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; + dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; + + dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; + dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; + dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; + dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; + dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; + dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; + dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; + + dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; + dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; + dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; + dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; + dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; + dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; + dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; + + dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; + dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; + dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; + dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; + dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; + dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; + dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; + + dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; + dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; + dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; + dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; + dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; + dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; + dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; + + dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; + dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; + dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; + dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; + dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; + dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; + dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; + + dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; + dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; + dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; + dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; + dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; + dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; + dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; + + dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; + dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; + dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; + dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; + dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; + dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; + dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; +#else + for (int j = 0; j < C8NUM; j++) { + const int16_t *local_ptr = src_ic8_ptr + j; + int16_t dst00 = local_ptr[0] * 2; + int16_t dst01 = (local_ptr + 8)[0] * 2; + int16_t dst02 = (local_ptr + 16)[0] * 2; + + int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr 
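/* Magnitude note: this scalar path, like the NEON path above it, applies the
 * doubled filter matrix 2G on both sides of g, so the transformed filter
 * carries a factor of 4 relative to textbook Winograd. The output transform
 * compensates by halving at each of its two stages, so no extra requantization
 * shift is needed for it. */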
+ 64)[0]; + + int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst30 = (local_ptr + 48)[0] * 2; + int16_t dst31 = (local_ptr + 56)[0] * 2; + int16_t dst32 = (local_ptr + 64)[0] * 2; + + int16_t m00 = dst00 * 2; + int16_t m01 = dst00 + dst01 + dst02; + int16_t m02 = dst00 - dst01 + dst02; + int16_t m03 = dst02 * 2; + + int16_t m10 = dst10 * 2; + int16_t m11 = dst10 + dst11 + dst12; + int16_t m12 = dst10 - dst11 + dst12; + int16_t m13 = dst12 * 2; + + int16_t m20 = dst20 * 2; + int16_t m21 = dst20 + dst21 + dst22; + int16_t m22 = dst20 - dst21 + dst22; + int16_t m23 = dst22 * 2; + + int16_t m30 = dst30 * 2; + int16_t m31 = dst30 + dst31 + dst32; + int16_t m32 = dst30 - dst31 + dst32; + int16_t m33 = dst32 * 2; + + *(dst_ic8_ptr + j * 4) = m00; + *(dst_ic8_ptr + j * 4 + dst_step) = m01; + *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; + *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; + + *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; + *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; + *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; + *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; + + *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; + *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; + *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; + *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; + + *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; + *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; + *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; + *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, int oc_start, + const ConvParameter *conv_param) { + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; + +#ifdef ENABLE_ARM + int32x4_t bias_ptr = vld1q_s32(bias_data); + + int32x4_t s00 = vld1q_s32(gemm_out); + int32x4_t s01 = vld1q_s32(gemm_out + 4); + int32x4_t s02 = vld1q_s32(gemm_out + 8); + int32x4_t s03 = vld1q_s32(gemm_out + 12); + + int32x4_t s10 = vld1q_s32(gemm_out + 16); + int32x4_t s11 = vld1q_s32(gemm_out + 20); + int32x4_t s12 = vld1q_s32(gemm_out + 24); + int32x4_t s13 = vld1q_s32(gemm_out + 28); + + int32x4_t s20 = vld1q_s32(gemm_out + 32); + int32x4_t s21 = vld1q_s32(gemm_out + 36); + int32x4_t s22 = vld1q_s32(gemm_out + 40); + int32x4_t s23 = vld1q_s32(gemm_out + 44); + + int32x4_t s30 = vld1q_s32(gemm_out + 48); + int32x4_t s31 = vld1q_s32(gemm_out + 52); + int32x4_t s32 = vld1q_s32(gemm_out + 56); + int32x4_t s33 = vld1q_s32(gemm_out + 60); + + int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); + int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); + int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); + int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); + + int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); + int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); + int32x4_t t12 = 
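/* Output transform: the first stage computes t_0j = (s_0j + s_1j + s_2j) >> 1
 * and t_1j = (s_1j - s_2j - s_3j) >> 1, i.e. AT = | 1  1  1  0 | applied on the
 *                                                 | 0  1 -1 -1 |
 * left; the second stage applies the same matrix on the right with another
 * >> 1. The two halvings cancel the x4 introduced by the doubled filter
 * transform; what follows is the usual fixed-point requantization (multiplier,
 * shifts, zero point, clamp). */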
vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); + int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); + + int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); + int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); + + int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); + int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); + + int32x4_t out_multiplier; + int32x4_t ls; + int32x4_t rs; + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + out_multiplier = vld1q_s32(quant_multiplier + oc_start); + ls = vld1q_s32(left_shift + oc_start); + rs = vld1q_s32(right_shift + oc_start); + } else { + out_multiplier = vdupq_n_s32(quant_multiplier[0]); + ls = vdupq_n_s32(left_shift[0]); + rs = vdupq_n_s32(right_shift[0]); + } + int32x4_t out_zp = vdupq_n_s32(output_zp); + int32x4_t output_min = vdupq_n_s32(out_min); + int32x4_t output_max = vdupq_n_s32(out_max); + + d00 = vqshlq_s32(d00, ls); + d00 = vqrdmulhq_s32(d00, out_multiplier); + int32x4_t carry = vandq_s32(d00, rs); + carry = vshrq_n_s32(carry, 31); + d00 = vqaddq_s32(d00, carry); + d00 = vqrshlq_s32(d00, rs); + d00 = vaddq_s32(d00, out_zp); + d00 = vmaxq_s32(d00, output_min); + d00 = vminq_s32(d00, output_max); + + d01 = vqshlq_s32(d01, ls); + d01 = vqrdmulhq_s32(d01, out_multiplier); + carry = vandq_s32(d01, rs); + carry = vshrq_n_s32(carry, 31); + d01 = vqaddq_s32(d01, carry); + d01 = vqrshlq_s32(d01, rs); + d01 = vaddq_s32(d01, out_zp); + d01 = vmaxq_s32(d01, output_min); + d01 = vminq_s32(d01, output_max); + + d10 = vqshlq_s32(d10, ls); + d10 = vqrdmulhq_s32(d10, out_multiplier); + carry = vandq_s32(d10, rs); + carry = vshrq_n_s32(carry, 31); + d10 = vqaddq_s32(d10, carry); + d10 = vqrshlq_s32(d10, rs); + d10 = vaddq_s32(d10, out_zp); + d10 = vmaxq_s32(d10, output_min); + d10 = vminq_s32(d10, output_max); + + d11 = vqshlq_s32(d11, ls); + d11 = vqrdmulhq_s32(d11, out_multiplier); + carry = vandq_s32(d11, rs); + carry = vshrq_n_s32(carry, 31); + d11 = vqaddq_s32(d11, carry); + d11 = vqrshlq_s32(d11, rs); + d11 = vaddq_s32(d11, out_zp); + d11 = vmaxq_s32(d11, output_min); + d11 = vminq_s32(d11, output_max); + + (output_data)[0] = (int8_t)d00[0]; + (output_data + 1)[0] = (int8_t)d00[1]; + (output_data + 2)[0] = (int8_t)d00[2]; + (output_data + 3)[0] = (int8_t)d00[3]; + + if (w_not_bound) { + *(output_data + 4) = (int8_t)d01[0]; + *(output_data + 5) = (int8_t)d01[1]; + *(output_data + 6) = (int8_t)d01[2]; + *(output_data + 7) = (int8_t)d01[3]; + } + if (h_not_bound) { + *(output_data + output_w * 4) = (int8_t)d10[0]; + *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; + *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; + *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; + if (w_not_bound) { + *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; + *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; + *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; + *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; + } + } +#else + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 
24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + int oc_index = oc_start + i; + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? 
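/* This per-channel branch and the per-tensor branch that follows are otherwise
 * identical: the only difference is whether left_shift / right_shift /
 * quant_multiplier are indexed by the real output channel (oc_start + i) or by
 * the single shared entry [0], as selected by the FILTER_PER_CHANNEL bit in
 * conv_quant_arg_.per_channel_ checked above. */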
d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } else { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? 
d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } +#endif +} + +void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, const ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + const int oc4 = UP_DIV(output_channel, C4NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const int32_t *src_ptr = gemm_out + src_oc4_offset; + const int32_t *bias_ptr = bias_data + j * C4NUM; + int8_t *dst_ptr = out_data + dst_oc4_offset; + + // output transform + int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, + conv_param); + } + } +} + +void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, const ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_l_; + int pad_h = conv_param->pad_u_; + ConvQuantArg quant_arg = conv_param->conv_quant_arg_; + int input_zp = quant_arg.input_quant_args_[0].zp_; + const int ic8 = UP_DIV(input_channel, C8NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? 
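/* Border handling in the input transform: tmp_data is prefilled with the input
 * zero point and only the in-image rows/columns are copied over it. Because
 * Conv3x3Int8InputUnit subtracts input_zp from every element, padded positions
 * contribute exactly zero to the transform, which matches zero padding in the
 * dequantized domain. */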
input_unit : (input_height - origin_y); + + int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + // copy data from origin input to tmp buffer + for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; + + int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; + for (int j = real_y_start; j < real_y_end; j++) { + const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); + int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); + memcpy(dst, src, (size_t)(real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); + } + // input transform + int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; + size_t dst_step = (size_t)ic8 * C8NUM * TILE_NUM; + int16_t *trans_input_ptr = trans_input + dst_ic8_offset; + Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); + } + } +} + +void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { + int oc4 = UP_DIV(oc, C4NUM); +#ifdef ENABLE_ARM + IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, (size_t)oc4 * 4 * 16 * sizeof(int32_t)); +#else + const int input_unit_square = 16; + for (int c = 0; c < oc4; c++) { + int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; + int dst_oc_offset = c * input_unit_square * C4NUM; + for (int n = 0; n < real_cal_num; n++) { + int src_tile_offset = n * C8NUM; + int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; + for (int i = 0; i < 4; i++) { + int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; + int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; + int dst_h_offset = dst_tile_offset + i * 4 * 4; + for (int m = 0; m < 4; m++) { + int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; + int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; + int dst_w_offset = dst_h_offset + m * C4NUM; + + int32_t acc[4] = {0}; + for (int z = 0; z < 4; z++) { + int filter_offset = filter_w_offset + z; + for (int j = 0; j < ic8; j++) { + int filter_c8_offset = filter_offset + j * 4 * 8; + int src_c8_offset = src_w_offset + j * 8 * 8; + + for (int k = 0; k < 8; k++) { + const int16_t *w_ptr = weight + filter_c8_offset + k * 4; + const int16_t *input_ptr = src + src_c8_offset + k; + acc[z] += w_ptr[0] * input_ptr[0]; + } + } + (dst + dst_w_offset + z)[0] = acc[z]; + } + } + } + } + } +#endif +} + +// int8 convolution 3x3 +void Conv3x3Int8(const int16_t *input_data, const int16_t *transed_weight, const int32_t *bias_data, + int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, + int8_t *tmp_out, int task_id, const ConvParameter *conv_param) { + int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + NNACL_CHECK_ZERO_RETURN(TILE_NUM); + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); + int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; + const int block_unit_buffer_offset = 16 * C8NUM; + int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; + + for (int batch = 0; batch < conv_param->input_batch_; batch++) { + int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; + int 
tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + + Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); + + Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, + transed_weight, conv_param->output_channel_, ic8, real_cal_num); + + Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, + bias_data, start_index, real_cal_num, out_w_block, conv_param); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..794c4ea613ee54d7c58b4e33917ef83b174549ad --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_CONV3X3_INT8_H_ +#define NNACL_INT8_CONV3X3_INT8_H_ + +#include <string.h> +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/common_func_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Int8(const int16_t *input_data, const int16_t *transed_weight, const int32_t *bias_data, + int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, + int8_t *tmp_out, int task_id, const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_CONV3X3_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..7d83429c7f4ab69b40d6159e59c1f90ffb9dee59 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c @@ -0,0 +1,825 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/conv_depthwise_int8.h" +#include <string.h> +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/int8/common_func_int8.h" + +/*conv depthwise int8 begin*/ +#ifndef ENABLE_ARM +void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, + int output_channel, int input_step, int8_t input_zp) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + const int16_t input = input_ptr[c] - input_zp; + *output_ptr++ += input * weight_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + int32_t acc_min, int32_t acc_max, bool per_channel) { + if (per_channel) { + // support per-channel quantization + for (int w = 0; w < output_w; w++) { + int channel4 = 0; +#ifdef ENABLE_ARM + channel4 = channel / 4 * 4; + ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min, + acc_max); +#endif + for (int c = channel4; c < channel; c++) { + buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + buffer[c] += output_zp; + buffer[c] = MSMAX(buffer[c], acc_min); + buffer[c] = MSMIN(buffer[c], acc_max); + dst[c] = (buffer[c]); + } + buffer += channel; + dst += channel; + } + } else { + int num_pixels = output_w * channel; + int align_num = 0; +#ifdef ENABLE_ARM + align_num = num_pixels / 4 * 4; + ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min, + acc_max); +#endif + for (int i = align_num; i < num_pixels; i++) { + buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + buffer[i] += output_zp; + buffer[i] = MSMAX(buffer[i], acc_min); + buffer[i] = MSMIN(buffer[i], acc_max); + dst[i] = (buffer[i]); + } + } +} + +void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, int task_id) { + int step_h = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int start_h = step_h * task_id; + int end_h = MSMIN(start_h + step_h, conv_param->output_h_); + + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + + int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const int8_t *src = 
input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = start_h; oh < end_h; oh++) { + int8_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + // init acc + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(row_buffer + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(int32_t)); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const int8_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const int16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + int32_t *acc_w = row_buffer + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const int8_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + + ConvDwInt8Row(acc_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step, input_zp); + weight_kh += conv_param->output_channel_; + } + } + // post func, acc int32 -> dst int8 + ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + } + } +} +/*conv depthwise int8 end*/ + +/*conv depthwise 3x3 int8 begin*/ +void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvParameter *conv_param, int block_input_h, + int block_input_w) { + for (int h = 0; h < block_input_h; h++) { + const int8_t *src = input; + for (int w = 0; w < block_input_w; w++) { + memcpy(buffer, src, 64); + src += conv_param->input_channel_; + buffer += 64; + } + input += conv_param->input_w_ * conv_param->input_channel_; + } +} + +void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size, + int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + int32_t acc_min, int32_t acc_max, int stride, bool per_channel) { + for (int w = 0; w < output_w; w++) { + int tmp_buffer[C8NUM]; + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + int8_t *output_tmp = output; + const int8_t *src_kh = buffer; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < 3; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < 3; kw++) { + for (int c = 0; c < 8; c++) { + tmp_buffer[c] += (src_kw[c] - in_zp) * weight_kw[c]; + } + src_kw += col_size; + weight_kw += 
channel; + } + src_kh += row_size; + weight_kh += 3 * channel; + } + if (per_channel) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + *output_tmp++ = (tmp_buffer[c]); + } + } else { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + *output_tmp++ = (tmp_buffer[c]); + } + } + output += channel; + buffer += col_size * stride; + } +} + +void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int start_c, + int end_c, int col_size, int row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, int stride, bool per_channel) { + for (; start_c <= end_c - 8; start_c += 8) { +#ifdef ENABLE_ARM64 + if (stride == 1) { + ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, per_channel); + } else { + ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, per_channel); + } + +#else + ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, stride, per_channel); +#endif + output += 8; + buffer += 8; + weight += 8; + bias += 8; + if (per_channel) { + out_multiplier += 8; + left_shift += 8; + right_shift += 8; + } + } +} + +void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const int16_t *weight, const int32_t *bias, + const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w, + int block_input_h, int block_input_w) { + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int in_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + + const int ih_offset = 64 * block_input_w; + int w = start_w; + if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) { + for (; w <= end_w - block_output_w; w += block_output_w) { + int8_t *output_ptr = output; + const int8_t *input_ptr = input; + const int16_t *weight_ptr = weight; + const int32_t *bias_ptr = bias; + int32_t *out_multiplier_ptr = out_multiplier; + int32_t *left_shift_ptr = left_shift; + int32_t *right_shift_ptr = right_shift; + int c = 0; + for (; c <= 
conv_param->output_channel_ - 64; c += 64) { + ConvDw3x3Int8InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); + ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, + block_output_h, block_output_w, in_zp, out_zp, out_multiplier_ptr, left_shift_ptr, + right_shift_ptr, acc_min, acc_max, conv_param->stride_h_, filter_per_channel); + output_ptr += 64; + input_ptr += 64; + weight_ptr += 64; + bias_ptr += 64; + if (filter_per_channel) { + out_multiplier_ptr += 64; + left_shift_ptr += 64; + right_shift_ptr += 64; + } + } + // left channel + ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, + conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, + conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier_ptr, + left_shift_ptr, right_shift_ptr, acc_min, acc_max, conv_param->stride_h_, filter_per_channel); + output += block_output_w * conv_param->input_channel_; + input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; + } + } + // left width + int left_width = end_w - w; + if (left_width > 0) { + ConvDw3x3Int8Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_, + conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h, + left_width, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + conv_param->stride_h_, filter_per_channel); + } +} + +void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + int output_h = sliding->bottom_ - sliding->top_; + int step_oh = UP_DIV(output_h, conv_param->thread_num_); + int start_oh = step_oh * task_id + sliding->top_; + int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_); + int start_ow = sliding->left_; + int end_ow = sliding->right_; + + const int block_output_h = 1; + int block_output_w = conv_param->stride_w_ == 1 ? 
30 : 14; + const int block_input_h = 3; + int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3; + + for (int b = 0; b < conv_param->output_batch_; b++) { + int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_; + const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ + + start_ih * conv_param->input_w_ * conv_param->input_channel_ + + start_iw * conv_param->input_channel_; + int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ + + start_oh * conv_param->output_w_ * conv_param->output_channel_ + + start_ow * conv_param->output_channel_; + + for (int oh = start_oh; oh < end_oh; oh++) { + ConvDw3x3Int8Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h, + block_output_w, block_input_h, block_input_w); + src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_; + dst += conv_param->output_w_ * conv_param->output_channel_; + } + } +} + +#ifndef ENABLE_ARM32 +void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + const int32_t acc_min, const int32_t acc_max, bool per_channel) { + for (int c = 0; c < channel; c += 8) { + int tmp_buffer[8]; + for (int i = 0; i < 8; i++) { + tmp_buffer[i] = 0; + } + const int8_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += (src_kw[c + i] - in_zp) * weight_kw[c + i]; + } + src_kw += in_kw_step; + weight_kw += channel; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += 3 * channel; + } // kernel_h loop + if (per_channel) { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += bias[c + i]; + tmp_buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[i] * (1 << (unsigned int)left_shift[i]), out_multiplier[i]), + -right_shift[i]); + tmp_buffer[i] += out_zp; + tmp_buffer[i] = MSMAX(tmp_buffer[i], acc_min); + tmp_buffer[i] = MSMIN(tmp_buffer[i], acc_max); + dst[i] = (tmp_buffer[i]); + } + left_shift += 8; + right_shift += 8; + out_multiplier += 8; + } else { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += bias[c + i]; + tmp_buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + tmp_buffer[i] += out_zp; + tmp_buffer[i] = MSMAX(tmp_buffer[i], acc_min); + tmp_buffer[i] = MSMIN(tmp_buffer[i], acc_max); + dst[i] = (tmp_buffer[i]); + } + } + dst += 8; + } +} +#endif + +#ifndef ENABLE_ARM64 +void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} + +void ConvDw3x3Int8Vertical(int8_t 
*dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} + +void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} +#endif + +void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int in_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + int input_row_size = conv_param->input_w_ * conv_param->input_channel_; + int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_; + int output_row_size = conv_param->output_w_ * conv_param->output_channel_; + int in_kh_step = sliding->in_kh_step_; + int in_kw_step = sliding->in_kw_step_; + + // top + for (int b = 0; b < conv_param->output_batch_; b++) { + const int8_t *input_batch = + input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + int8_t *output_batch = + output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + + const int8_t *input = input_batch; + const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_; + int8_t *output = output_batch; + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; + weight = weight_data + weight_row_size; + output += conv_param->output_channel_; + for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { + ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ * conv_param->input_channel_; + output += conv_param->output_channel_; + } + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + + // left + input = input_batch + (conv_param->stride_h_ - 1) * 
input_row_size; + weight = weight_data + conv_param->input_channel_; + output = output_batch + output_row_size; + for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { + ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, + in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + filter_per_channel); + input += conv_param->stride_h_ * input_row_size; + output += output_row_size; + } + + // right + input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ + + (conv_param->stride_h_ - 1) * input_row_size; + weight = weight_data; + output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; + for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { + ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, + in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + filter_per_channel); + input += conv_param->stride_h_ * input_row_size; + output += output_row_size; + } + + // bottom + input = input_batch + (conv_param->input_h_ - 2) * input_row_size; + weight = weight_data + conv_param->input_channel_; + output = output_batch + (conv_param->output_h_ - 1) * output_row_size; + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; + weight = weight_data; + output += conv_param->output_channel_; + for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { + ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ * conv_param->input_channel_; + output += conv_param->output_channel_; + } + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + } +} +/*conv depthwise 3x3 int8 end*/ + +/*conv depthwise sliding window perchannel int8 begin*/ +void ConvDwInt8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w, const int8_t *input_zp, + const int32_t *out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *acc_min, const int32_t *acc_max) { + int tmp_buffer[C8NUM]; + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + const int8_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = 
MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); + dst[c] = (tmp_buffer[c]); + } +} + +void ConvDwInt8Border(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top, int bottom, + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + const int8_t *in_zp, const int32_t *out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *acc_min, + const int32_t *acc_max) { + int8_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const int8_t *src_h = src + ih * sliding->in_h_step_; + + int8_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const int8_t *src_w = src_h + iw * sliding->block_channel_; + + const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + + ConvDwInt8BorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM +void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, const int8_t *in_zp, const int32_t *out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + const int32_t *acc_min, const int32_t *acc_max) { + int tmp_buffer[C8NUM]; + int8_t *dst_h = dst; + const int8_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int8_t *dst_w = dst_h; + const int8_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const int8_t *src_kh = src_w; + const int16_t *weight_kh = weight; + + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add bias relu + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); + dst_w[c] = (tmp_buffer[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // 
dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +void ConvDwInt8SW(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const int8_t *input_zp, const int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + const int8_t *src = input_data; + int8_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const int8_t *src_data = src + oc * C8NUM; + int8_t *dst_data = dst + oc * C8NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C8NUM; + + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM; + int32_t *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM; + int32_t *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM; + const int8_t *in_zp = input_zp + oc * C8NUM; + const int32_t *out_zp = output_zp + oc * C8NUM; + + ConvDwInt8Border(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc8 +} +/*conv depthwise sliding window perchannel int8 end*/ + +/*deconv depthwise int8 begin*/ +void DeconvDwInt8BorderPixel(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + int32_t *dst_kh = dst; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t 
*weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDwInt8Border(int32_t *dst, const int16_t *src, const int16_t *weight, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const int16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int32_t *dst_h = dst + oh * sliding->in_h_step_; + + const int16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + int32_t *dst_w = dst_h + ow * C4NUM; + + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + int32_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDwInt8BorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM +void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, + int in_kw_step) { + int32_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int32_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + int32_t *dst_kh = dst_w; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +#ifndef ENABLE_ARM +void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, + int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, + int32_t acc_max) { + int8_t *dst_k = dst; + int32_t *buffer_k = output_buffer; + for (int k = 0; k < pixel_nums; k++) { + for (int c = 0; c < C4NUM; c++) { + buffer_k[c] += bias[c]; + buffer_k[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer_k[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + buffer_k[c] += out_zp; + buffer_k[c] = MSMAX(buffer_k[c], acc_min); + buffer_k[c] = MSMIN(buffer_k[c], acc_max); + dst_k[c] = (buffer_k[c]); + } + dst_k += block_channel; + buffer_k += C4NUM; + } +} +#endif 
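+
+/* Scalar reference for the fixed-point requantization used throughout this
+ * file (matching DeconvDwInt8Post above and the ConvDw*Post helpers):
+ *   acc += bias[c];
+ *   acc = RoundingDivideByPOT(
+ *           SaturatingRoundingDoublingHighMul(acc * (1 << left_shift), out_multiplier),
+ *           -right_shift);
+ *   acc = MSMIN(MSMAX(acc + out_zp, acc_min), acc_max);
+ * DeconvDwInt8 below drives the deconv depthwise kernel: it zeroes the int32
+ * output_buffer for each C4 channel block, scatters input * weight products
+ * into it (borders via DeconvDwInt8Border, interior via DeconvDwInt8Center),
+ * then converts the accumulated plane to int8 with DeconvDwInt8Post. */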
+ +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + int buffer_size = conv_param->output_h_ * conv_param->output_w_ * C4NUM; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + memset(output_buffer, 0, buffer_size * sizeof(int32_t)); + const int16_t *src_data = src + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + DeconvDwInt8Border(output_buffer, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const int16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int16_t), + sliding->block_channel_ * sizeof(int16_t), sliding->in_sh_step_ * sizeof(int32_t), + sliding->in_sw_step_ * sizeof(int32_t), sliding->in_kh_step_ * sizeof(int32_t), + sliding->in_kw_step_ * sizeof(int32_t)); +#else + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDwInt8Post(dst_data, output_buffer, bias, sliding->block_channel_, + conv_param->output_h_ * conv_param->output_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + } // output C4 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise int8 end*/ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..60150821a09adeb2c93401f5955d61e5f92f59d5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_CONV_DEPTHWISE_H_ +#define NNACL_INT8_CONV_DEPTHWISE_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, int task_id); + +void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding); + +void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void ConvDwInt8SW(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const int8_t *input_zp, const int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id); + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..9904230cf3e167ecff168586af5b83dafccd9b90 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c @@ -0,0 +1,913 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/conv_int8.h" + +#ifdef ENABLE_ARM32 +void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); + +#ifdef ENABLE_ARM32 + size_t oc_div2 = output_channel / C2NUM * C2NUM; + size_t oc_res2 = output_channel - oc_div2; + size_t inputsum_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsum_stride); +#else + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci2div = ci / C2NUM, ci2mod = ci % C2NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} +#endif + +void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); +#ifdef ENABLE_ARM64 + size_t oc_div4 = output_channel / C4NUM * C4NUM; + size_t oc_res4 = output_channel - oc_div4; + size_t inputsum_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsum_stride); +#else + + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} + +void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) { + int ic4 = UP_ROUND(input_channel, C4NUM); + int oc8 = UP_ROUND(output_channel, C8NUM); + int hw8 = UP_ROUND(plane_size, C8NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t oc_8div = output_channel / C8NUM * C8NUM; + size_t oc_8res = output_channel - oc_8div; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + int32_t *input_sum_r = input_sum; + + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_oc = input_sum_r; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + + "mov x10, 
%[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "dup v0.4s, v16.s[0] \n" + "dup v1.4s, v16.s[1] \n" + "dup v2.4s, v16.s[2] \n" + "dup v3.4s, v16.s[3] \n" + "dup v4.4s, v17.s[0] \n" + "dup v5.4s, v17.s[1] \n" + "dup v6.4s, v17.s[2] \n" + "dup v7.4s, v17.s[3] \n" + "mov x4, #0 \n" + "mov x10, 
%[filter_zp] \n" + "mov x11, %[input_sum_oc] \n" + + "7: \n" + "cmp x4, %[oc_8div] \n" + "beq 8f \n" + "add x4, x4, #8\n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.4s}, [x10], #16\n" + + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "add x11, x11, %[input_sum_stride] \n" + "b 7b \n" + + "8: \n" + "cmp %[oc_8res], #0\n" + "beq 17f \n" + + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "cmp %[oc_8res], #1\n" + "beq 9f \n" + "cmp %[oc_8res], #2\n" + "beq 10f \n" + "cmp %[oc_8res], #3\n" + "beq 11f \n" + "cmp %[oc_8res], #4\n" + "beq 12f \n" + "cmp %[oc_8res], #5\n" + "beq 13f \n" + "cmp %[oc_8res], #6\n" + "beq 14f \n" + "cmp %[oc_8res], #7\n" + "beq 15f \n" + + "9: \n" + "ld1 {v16.s}[0], [x10] \n" + "b 16f \n" + + "10: \n" + "ld1 {v16.d}[0], [x10] \n" + "b 16f \n" + + "11: \n" + "ld1 {v16.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v16.s}[2], [x10] \n" + "b 16f \n" + + "12: \n" + "ld1 {v16.4s}, [x10] \n" + "b 16f \n" + + "13: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.s}[0], [x10] \n" + "b 16f \n" + + "14: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "b 16f \n" + + "15: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v17.s}[2], [x10] \n" + "b 16f \n" + + "16: \n" + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "17: \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [filter_zp] "r"(filter_zp), [input_sum_oc] "r"(input_sum_oc), + [input_sum_stride] "r"(input_sum_stride), [src_stride] "r"(src_stride), [ic_4div] "r"(ic_4div), + [ic_4res] 
"r"(ic_4res), [oc_8div] "r"(oc_8div), [oc_8res] "r"(oc_8res) + : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; + input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; + input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; + input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; + input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; + input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; + input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; + input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; + } + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = 0; + } + } + } /* oc8 res done */ +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + input_sum_r += C8NUM * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t *input_sum_oc = input_sum_r; + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int curoi = 0; curoi < C8NUM; curoi++) { + input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; + } + for (int oci = oc_8res; oci < 
C8NUM; oci += 1) { + input_sum_oc[oci] = 0; + } + } /* oc8 res done */ + + src_r += input_channel; + pack_r += C4NUM; + input_sum_r += C8NUM; + } + + for (int hwi = plane_size; hwi < hw8; hwi++) { + for (int oc = 0; oc < oc8; oc++) { + int oc8div = oc / C8NUM, oc8res = oc % C8NUM; + input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; + } + } + } + return; +} + +void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t plane_size, const ConvParameter *conv_param) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v20.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr 
\n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "mul v16.4s, v16.4s, v20.4s \n" + "mul v17.4s, v17.4s, v20.4s \n" + + "st1 {v16.4s}, [x14], #16 \n" + "st1 {v17.4s}, [x14], #16 \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [input_sum_r] "r"(input_sum_r), [src_stride] "r"(src_stride), + [ic_4div] "r"(ic_4div), [ic_4res] "r"(ic_4res), [filter_zp] "r"(filter_zp) + : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", + "v20"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int i = 0; i < C8NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +void PackInputSum16x4Int8(const int8_t 
*input, int32_t *input_sum, const int32_t *filter_zp, + const ConvParameter *conv_param) { + size_t hw = conv_param->output_h_ * conv_param->output_w_; + size_t hw4 = UP_ROUND(hw, C4NUM); + size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); + } else { +#ifdef ENABLE_ARM32 + PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#endif + } + return; +} + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, + int block_index, const int32_t *filter_zp, int32_t *input_sum, + const ConvParameter *conv_param, bool per_channel, bool is_optimize) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int kernel_plane = kernel_h * kernel_w; + NNACL_CHECK_ZERO_RETURN(out_w); + NNACL_CHECK_ZERO_RETURN(dilation_h); + NNACL_CHECK_ZERO_RETURN(dilation_w); + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_stride = input_h * in_w * in_channel + input_w * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (kw_e <= kw_s || kh_e <= kh_s) { + continue; + } + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel); + } + } // kernel_h loop + } + } // tile num loop + int deep = kernel_plane * in_channel; + if (is_optimize) { + if (per_channel) { + Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num, + filter_zp, C8NUM * C8NUM); + } else { + Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param); + } + } else { + RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep); + if (per_channel) { +#ifdef ENABLE_ARM32 + 
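+      /* per-channel filter zero points: the input sums form a hw x oc table (one entry
+       * per output channel) instead of the single per-row sum the per-layer path below uses */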
PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_); +#endif + } else { + size_t hw4 = UP_ROUND(real_cal_num, C4NUM); + size_t ic16 = UP_ROUND(deep, C16NUM); + PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, + ic16); + } + } +} + +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, + const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, + ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) { + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int tile_n = conv_param->tile_num_; + int output_count = conv_param->output_h_ * conv_param->output_w_; + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_tile_count = UP_DIV(output_count, tile_n); + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + int unit_size; + int input_sum_offset; + int up_round_oc; +#ifdef ENABLE_ARM32 + up_round_oc = UP_ROUND(out_channel, C2NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C16NUM); +#else + if (is_optimize) { + up_round_oc = UP_ROUND(out_channel, C8NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C4NUM); + } else { + up_round_oc = UP_ROUND(out_channel, C4NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C16NUM); + } +#endif + bool per_channel = false; + if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { + input_sum_offset = tile_n * up_round_oc; + per_channel = true; + } else { + input_sum_offset = tile_n; + per_channel = false; + } + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? 
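+      /* tail tile: the last tile of a plane may cover fewer than tile_n output
+       * positions, so the im2col pack and GEMM below process real_cal_num rows */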
(output_count - start_index) : tile_n; + int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset; + int8_t *gemm_input = packed_input + task_id * unit_size * tile_n; + int8_t *matmul = matmul_input + task_id * kernel_plane * in_channel * tile_n; + memset(matmul, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, kernel_plane * in_channel * tile_n); + Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, matmul, real_cal_num, start_index, filter_zp, + tmp_input_sum, conv_param, per_channel, is_optimize); + + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + int8_t *gemm_output = output_data + out_offset; +#ifdef ENABLE_ARM32 + MatmulInt8Neon32( + gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, tmp_input_sum, bias_data, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, out_channel, per_channel); +#elif ENABLE_ARM64 + if (is_optimize) { + matmul_func(gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel, + tmp_input_sum, bias_data, conv_param->conv_quant_arg_.left_shift_, + conv_param->conv_quant_arg_.right_shift_, conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0], per_channel); + } else { + MatmulInt8Neon64(gemm_input, packed_weight, gemm_output, UP_ROUND(real_cal_num, C4NUM), + UP_ROUND(out_channel, C4NUM), unit_size, tmp_input_sum, bias_data, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.left_shift_, + conv_param->conv_quant_arg_.right_shift_, real_cal_num, out_channel, out_channel, per_channel); + } +#else + MatMulInt8_8x8_r( + gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel, tmp_input_sum, + bias_data, conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel); +#endif + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..18318489d5816d58a52a26f5af514a5c5531d6fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_INT8_CONV_INT8_H_ +#define NNACL_INT8_CONV_INT8_H_ + +#include <string.h> +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/common_func_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif +// int8 conv common +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, +              const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, +              ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize); + +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_CONV_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..56205628e68762c0109d9057b2ec9e95f041d4bc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c @@ -0,0 +1,236 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/crop_parameter.h" +#include <string.h> +#include <math.h> +#include <float.h> +#include "nnacl_c/int8/crop_int8.h" + +void Int8Crop1D(const int8_t *input, int8_t *output, int *output_shape, int64_t *in_offset, int task_id, +                int thread_count, const CropQuantArg *quant) { + const int out_batch = output_shape[0]; + int64_t task_id_stride = thread_count > 1 ? 
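+  /* each worker crops a contiguous block of ceil(out_batch / thread_count) elements;
+   * when the input and output quant params match the copy below is a plain memcpy,
+   * otherwise each value is requantized as round(in * scale + bias) + out_zp and clamped */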
UP_DIV(out_batch, thread_count) : out_batch; + if (task_id_stride <= 0) { + return; + } + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + int n = task_id * task_id_stride; + if (n >= out_batch) { + return; + } + const int8_t *in_ptr = input + n + in_offset[0]; + int8_t *out_ptr = output + n; + int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride); + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + return; +} + +void Int8Crop2D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + int h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const int8_t *in_ptr = input + (n + in_offset[0]) * in_height + h + in_offset[1]; + int8_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + return; +} + +void Int8Crop3D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int in_width = input_shape[2]; + + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + const int out_width = output_shape[2]; + + int64_t task_id_stride = thread_count > 1 ? 
UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const int8_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + in_offset[2]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_width); + } else { + for (int i = 0; i < out_width; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + } + return; +} + +void Int8Crop4D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int in_width = input_shape[2]; + const int in_channel = input_shape[3]; + + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + const int out_width = output_shape[2]; + const int out_channel = output_shape[3]; + + int64_t task_id_stride = thread_count > 1 ? 
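+  /* same per-row split as the 2D/3D variants; only the innermost contiguous copy
+   * changes, here out_channel values per (n, h, w) position */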
UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const int8_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + + (w + in_offset[2]) * in_stride_w + in_offset[3]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_channel); + } else { + for (int i = 0; i < out_channel; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + } + } + return; +} + +void Int8Crop(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_count, const CropQuantArg *quant) { + switch (input_dim) { + case 1: + Int8Crop1D(input, output, output_shape, in_offset, task_id, thread_count, quant); + break; + case 2: + Int8Crop2D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + case 3: + Int8Crop3D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + case 4: + Int8Crop4D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..b4cc4f263e45c27abb981ad0a67fd28af2b0ae89 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_CROP_INT8_H_ +#define NNACL_INT8_CROP_INT8_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Crop(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_count, const CropQuantArg *quant); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CROP_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..cc2994e2a12a00e657161a1ec60224430baaa0b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/deconv_int8.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/common_func_int8.h" +int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + const ConvParameter *conv_param) { + /* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */ + int input_plane = conv_param->input_w_ * conv_param->input_h_; + int kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int in_plane4 = UP_ROUND(input_plane, C4NUM); + + int src_iw_stride = C4NUM; + int src_ih_stride = conv_param->input_w_ * C4NUM; + int src_kw_stride = in_plane4 * C4NUM; + int src_kh_stride = in_plane4 * conv_param->kernel_w_ * C4NUM; + int dst_oh_stride = conv_param->output_w_ * C4NUM; + int dst_ow_stride = C4NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C4NUM; + int dst_kw_stride = conv_param->dilation_w_ * C4NUM; + + for (int c = 0; c < oc4; c++) { + int32_t *dst_ptr = tmp + c * output_plane * C4NUM; + const int32_t *src_ptr = src + c * in_plane4 * kernel_plane * C4NUM; + memset(dst_ptr, 0, (size_t)output_plane * C4NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + 
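+          /* scatter-add this input position's C4 output block into the overlapping
+           * output window; on Arm64 the 4-lane accumulation below is a single NEON add */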
int32_t *tmp_dst = dst_ptr + dst_index; + const int32_t *tmp_src = src_ptr + src_index; +#ifndef ENABLE_ARM64 + for (int i = 0; i < C4NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#else + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s}, [x0] \n" + "ld1 {v1.4s}, [x1] \n" + + "add v0.4s, v0.4s, v1.4s \n" + + "st1 {v0.4s}, [x1] \n" + + : + : [tmp_src] "r"(tmp_src), [tmp_dst] "r"(tmp_dst) + : "x0", "x1", "v0", "v1"); +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc*/ + + PostFuncInt8C4(tmp, bias, out, output_channel, (size_t)output_plane, conv_param->output_channel_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + return NNACL_OK; +} + +void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, + bool support_optimize_) { + /* optimize normal -> same layout */ + int ic16 = UP_ROUND(input_channel, C16NUM); + int oc4 = UP_ROUND(output_channel, C4NUM); + for (int ic = 0; ic < input_channel; ic++) { + int ic16div = ic / C16NUM, ic16mod = ic % C16NUM; + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * output_channel * plane + hw * output_channel + oc; + int dst_index = hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod; + dst[dst_index] = src[src_index]; + } + } + } + return; +} + +void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, + int col4, bool suppport_opt) { + int deep16 = UP_ROUND(deep, C16NUM); + int32_t zp_sum = filter_zp * input_zp * deep; + for (int c = 0; c < col4; c++) { + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + for (int r = 0; r < deep; r++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod; + value += weight[src_index]; + } + weight_sum[c] = zp_sum - value * input_zp; + } + return; +} + +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, + bool suppport_opt) { + /* optimize normal -> same layout */ + PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16); + return; +} + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, const int32_t *weight_sum, + const int32_t *input_sum, size_t act_row, size_t act_col, size_t act_deep, + const ConvParameter *conv_param, MATMUL_OPT_R4_FUNC matmul_func) { + if (matmul_func != NULL) { + matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); + } else { + MatMulInt8_16x4(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); + } + return NNACL_OK; +} + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param, bool support_optimize) { + /* optimize normal -> same layout (C4) */ + int error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param); + return error_code; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h new file mode 100644 index 
0000000000000000000000000000000000000000..60a7a31d11d31672f79c346791589256c841976b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_DECONV_H_ +#define NNACL_INT8_DECONV_H_ + +#include <string.h> +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/int8/matmul_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, +                         int col4, bool suppport_opt); +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, +                        bool suppport_opt); +void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, +                           bool support_optimize_); + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, const int32_t *weight_sum, +               const int32_t *input_sum, size_t act_row, size_t act_col, size_t act_deep, +               const ConvParameter *conv_param, MATMUL_OPT_R4_FUNC matmul_func); +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, +                   ConvParameter *conv_param, bool support_optimize); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_DECONV_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..be317eeaf0599543bafdff40975f9264af57c6c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/int8/depth_to_space_int8.h" +#include + +void DepthToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, DepthToSpaceArgs *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + int64_t copy_size = block_size * param->out_stride_dim2_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + int64_t out_offset = out_offset_w + l * param->out_stride_dim1_; + int64_t in_offset = in_offset_w + l * block_size * param->out_stride_dim2_; + for (int m = 0; m < copy_size; ++m) { + int32_t output_tmp = round(input[in_offset + m] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[out_offset + m] = output_tmp; + } + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..faaba5891b18836c110a1badb8e5fe9631783fa3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ +#define NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ + +#include "nnacl_c/depth_to_space_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/depth_to_space.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DepthToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, DepthToSpaceArgs *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..2dc62cd3dfd68c1aceb2e54dbca41712f90e55c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/div_int8.h" + +int DivInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para) { + int index = 0; + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} + +int DivScalarInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para) { + int index = 0; + const int32_t input1_val = para->in1_args_.zp_ + *input1_data; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? 
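+  /* the divisor is constant for the whole tile, so its fixed-point reciprocal is
+   * computed once (with the sign factored out) and reused for every element in the loop below */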
ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..dcd6210d20476fd72a76901964f00290e057abe0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DIV_INT8_H_ +#define NNACL_INT8_DIV_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/fixed_point.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DivInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para); + +int DivScalarInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DIV_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..a1aad3b2edcff5eb452a5aa4e6c4c4a41a61d708 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c @@ -0,0 +1,76 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include "nnacl_c/int8/dynamic_gather_int8.h" +#include "nnacl_c/op_base.h" + +void DynamicGather(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float *output, const float *scale_in, const int32_t *zp_in) { + for (int m = 0; m < outer_size; ++m) { + const int8_t *int8_in_m = input + inner_size * m * limit; + float *int8_out_m = output + inner_size * m * indices_element_size; + for (int i = 0; i < indices_element_size; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + const float scale = scale_in[index]; + const int zp = zp_in[index]; + float *out = int8_out_m + i * inner_size; + const int8_t *src = int8_in_m + index * inner_size; +#ifndef ENABLE_ARM64 + for (int j = 0; j < inner_size; ++j) { + out[j] = (src[j] - zp) * scale; + } +#else + int count_16 = DOWN_ROUND(inner_size, C16NUM); + DynamicGatherArm64(src, out, count_16, zp, scale); + for (int j = count_16; j < inner_size; ++j) { + out[j] = (src[j] - zp) * scale; + } +#endif + } + } + return; +} + +#ifdef ENABLE_FP16 +void DynamicGatherForFp16(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float16_t *output, const float *scale_in, const int32_t *zp_in) { + for (int m = 0; m < outer_size; ++m) { + const int8_t *int8_in_m = input + inner_size * m * limit; + float16_t *int8_out_m = output + inner_size * m * indices_element_size; + for (int i = 0; i < indices_element_size; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + const float scale = scale_in[index]; + const int zp = zp_in[index]; + float16_t *out = int8_out_m + i * inner_size; + const int8_t *src = int8_in_m + index * inner_size; +#ifndef ENABLE_ARM64 + for (int j = 0; j < inner_size; ++j) { + out[j] = (float16_t)(src[j] - zp) * scale; + } +#else + int count_16 = DOWN_ROUND(inner_size, C16NUM); + DynamicGatherArm64ForFp16(src, out, count_16, zp, scale); + for (int j = count_16; j < inner_size; ++j) { + out[j] = (float16_t)((src[j] - zp) * scale); + } +#endif + } + } + return; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..b51491b7898b18927285894fd1a2d3fa2614144f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_DYNAMIC_GATHER_INT8_H_ +#define NNACL_INT8_DYNAMIC_GATHER_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DynamicGather(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float *output, const float *scale_in, const int32_t *zp_in); +void DynamicGatherArm64(const int8_t *src, float *output, int count_16, int zp, float scale); + +#ifdef ENABLE_FP16 +void DynamicGatherForFp16(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float16_t *output, const float *scale_in, const int32_t *zp_in); +void DynamicGatherArm64ForFp16(const int8_t *src, float16_t *output, int count_16, int zp, float scale); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_GATHER_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..5c4d49d89aa4e4b21fe7eb61b97efdeb13739147 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c @@ -0,0 +1,420 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/dynamic_matmul_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? 
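+      /* mode selects the multi_scales layout: 0 = one per-tensor scale, 1 = per output
+       * column, C2NUM = per output row, otherwise a full row-major r * C16NUM + c table */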
r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} + +#ifdef ENABLE_FP16 +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float16_t) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} +#endif + +void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type) { + /* * + * row4x16-major * row16x4-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + int32_t s0 = 0; + int32_t s1 = 0; + int32_t s2 = 0; + int32_t s3 = 0; + for (int d = 0; d < deep; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + s0 += a[ai] * b[bi]; + s1 += filter_zp * a[ai]; + s2 += input_zp * b[bi]; + s3 += input_zp * filter_zp; + } + value = s0 - s1 - s2 + s3; + int input_quant_index = input_per_channel ? r : 0; + int filter_quant_index = filter_per_channel ? 
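+      /* value = s0 - s1 - s2 + s3 expands sum((a - input_zp) * (b - filter_zp)) over the
+       * depth: s1/s2 are the cross zero-point terms and s3 totals input_zp * filter_zp;
+       * each side then picks its per-channel or per-tensor quant scale */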
c : 0; + float multi_scale = input_scale[input_quant_index] * filter_scale[filter_quant_index]; + size_t ci = r * stride + c; + dst[ci] = multi_scale * value; + if (bias != NULL) { + dst[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + dst[ci] = MSMAX(0, dst[ci]); + } else if (act_type == ActType_Relu6) { + dst[ci] = MSMAX(0, dst[ci]); + dst[ci] = MSMIN(C6NUM, dst[ci]); + } + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size_t input_channel) { + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v2.4s, wzr \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "6: \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [src_stride] "r"(src_stride), [ic_4div] "r"(ic_4div), + [ic_4res] "r"(ic_4res) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); +} +#endif + +void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; +#ifdef ENABLE_ARM64 + PackInput4x4Asm(src_ic, pack_ic, ic_4div, input_channel); +#else + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (size_t i = 0; i < C4NUM; i++) { + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for 
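/* zero-pad the remaining channels up to ic4 */ 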
(int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + memset(pack_r, 0, C4NUM * ic4); + for (int hwi = hw_4div; hwi < plane_size; hwi += 1) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + src_r += input_channel; + pack_r += C4NUM; + } + } + return; +} + +// For matmul input a transpose case +void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride) { + const int row_tile = C4NUM; + int row_align = UP_ROUND(row, row_tile); + int row_div = row / row_tile * row_tile; + const int row_res = row - row_div; + + const int col_tile = C4NUM; + int col_div = col / col_tile * col_tile; + const int col_res = col - col_div; + + const int8_t *src_ic = NULL; + int8_t *packed_ic = NULL; + for (int c = 0; c < col_div; c += C4NUM) { + int r = 0; + src_ic = src_input + c; + packed_ic = packed_input + c * row_align; +#ifdef ENABLE_ARM64 + size_t row_stride_int64 = row_stride; + asm volatile( + "mov w10, %w[row]\n" + "mov x11, %[src_ic]\n" + "mov x12, %[packed_ic]\n" + "cmp w10, wzr\n" + "beq 1f\n" + "2:\n" + "subs w10, w10, #4\n" + "ld1 {v0.s}[0], [x11], %[row_stride]\n" + "ld1 {v1.s}[0], [x11], %[row_stride]\n" + "ld1 {v0.s}[1], [x11], %[row_stride]\n" + "ld1 {v1.s}[1], [x11], %[row_stride]\n" + "zip1 v2.8b, v0.8b, v1.8b\n" + "zip2 v3.8b, v0.8b, v1.8b\n" + "zip1 v4.4h, v2.4h, v3.4h\n" + "zip2 v5.4h, v2.4h, v3.4h\n" + "st1 {v4.4h, v5.4h}, [x12], #16\n" + + "bgt 2b\n" + "1:\n" + + : + : [src_ic] "r"(src_ic), [packed_ic] "r"(packed_ic), [row] "r"(row_div), [row_stride] "r"(row_stride_int64) + : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); + packed_ic += C4NUM * row_div; + src_ic += row_div * row_stride; +#else + for (; r < row_div; r += C4NUM) { + for (int i = 0; i < row_tile; i++) { + packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0]; + packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1]; + packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2]; + packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3]; + } + packed_ic += C16NUM; + src_ic += row_tile * row_stride; + } +#endif + for (r = 0; r < row_res; ++r) { + for (int i = 0; i < C4NUM; ++i) { + packed_ic[i * row_tile + r] = src_ic[r * row_stride + i]; + } + } + } + if (col_res == 0) { + return; + } + src_ic = src_input + col_div; + packed_ic = packed_input + row_align * col_div; + for (int r = 0; r < row_div; r += row_tile) { + for (int i = 0; i < col_res; ++i) { + packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i]; + packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i]; + packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i]; + packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i]; + } + src_ic += row_tile * row_stride; + packed_ic += row_tile * col_tile; + } + + for (int r = 0; r < row_res; ++r) { + for (int c = 0; c < col_res; ++c) { + packed_ic[c * row_tile + r] = src_ic[r * row_stride + c]; + } + } +} + +void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order) { + if (order == RowMajor) { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += 
weight[r * col + c]; + } + dst[c] = sum; + } + } else { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[c * row + r]; + } + dst[c] = sum; + } + } + return; +} + +void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order) { + if (order == RowMajor) { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[r * stride + c]; + } + dst[c] = sum; + } + } else { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[c * row + r]; + } + dst[c] = sum; + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..5b36268e29311ad44da2ab0e1633413aba5a777f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_MATMUL_H_ +#define NNACL_INT8_DYNAMIC_MATMUL_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride); +void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size); +void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type); +void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order); +void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order); +#if defined(ENABLE_ARM64) && !defined(USE_AOS_GCC_TOOLCHAIN) +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +#endif +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. 
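+ * For an output element (r, c) this selects multi_scales[0] (mode 0), multi_scales[c] (mode 1), + * multi_scales[r] (mode 2) or multi_scales[r * C16NUM + c] (mode 3); note that stride is given + * in bytes, so consecutive output rows are stride / sizeof(float) elements apart.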
+ */ +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +#ifdef ENABLE_FP16 +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode); +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3cd0669c58f07833ef4445892db2793d04abaa7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c @@ -0,0 +1,91 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/dynamic_quant_int8.h" + +void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) { + if (count == 0) { + return; + } +#ifndef ENABLE_ARM64 + for (int i = 0; i < count; ++i) { + *real_min = data[i] < *real_min ? data[i] : *real_min; + *real_max = data[i] > *real_max ? data[i] : *real_max; + } +#else + // volatile keeps count_4 live in memory so the compiler cannot optimize it away around the inline asm. + volatile int count_4 = DOWN_ROUND(count, C4NUM); + asm volatile( + "mov x4, %[data]\n" // reload data + "mov w5, %w[count_4]\n" // reload count + "ld1 {v31.4s}, [x4]\n" // min + "ld1 {v30.4s}, [x4], #16\n" // max + "subs w5, w5, #4\n" + "ble 1f\n" + + "0:\n" + "ld1 {v0.4s}, [x4], #16\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "subs w5, w5, #4\n" + "bgt 0b\n" + + "1:\n" + "fminv s6, v31.4s\n" + "fmaxv s7, v30.4s\n" + + "str s6, [%[real_min]]\n" + "str s7, [%[real_max]]\n" + + : + : [data] "r"(data), [count_4] "r"(count_4), [real_min] "r"(real_min), [real_max] "r"(real_max) + : "x4", "w5", "s6", "s7", "v0", "v30", "v31"); + for (int i = count_4; i < count; ++i) { + *real_min = data[i] < *real_min ? 
data[i] : *real_min; + *real_max = data[i] > *real_max ? data[i] : *real_max; + } +#endif +} + +void CalculateChannelRowMinMax(const float *data, int count, float *real_min, float *real_max, int row_length) { + if (row_length == 0) { + return; + } + int channel_total = count / row_length; + for (int i = 0; i < channel_total; i++) { + CalculateMinMaxFp32(data + i * row_length, row_length, real_min + i, real_max + i); + } +} + +void CalculateChannelColMinMax(const float *data, int count, float *real_min, float *real_max, int row_length) { + if (row_length == 0) { + return; + } + int row_total = count / row_length; + for (int r = 0; r < row_total; r++) { + const float *data_current = data + r * row_length; + for (int c = 0; c < row_length; c++) { + float *real_min_channel = real_min + c; + float *real_max_channel = real_max + c; + if (data_current[c] < *real_min_channel) { + *real_min_channel = data_current[c]; + } + if (data_current[c] > *real_max_channel) { + *real_max_channel = data_current[c]; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..05f26e68500c1ec639558dec44d89c4433bba1cc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_QUANT_INT8_H_ +#define NNACL_INT8_DYNAMIC_QUANT_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/pow_parameter.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max); +void CalculateChannelRowMinMax(const float *data, int count, float *real_min, float *real_max, int row_length); +void CalculateChannelColMinMax(const float *data, int count, float *real_min, float *real_max, int row_length); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_QUANT_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c new file mode 100644 index 0000000000000000000000000000000000000000..e77ba048b96c5882840c018046a71f25aa1e131f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c @@ -0,0 +1,276 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/fixed_point.h" + +#define C31NUM 31 + +// returns the high-32 bits of a * b with rounding +// assume that a and b are divided by 2^31, so they fall into [-1, 1] +// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31 +// actually we compute 2 * a * b / 2^32 +// and take 32 bits of mantissa for rounding +int SaturatingRoundingDoublingHighMul(int a, int b) { + if (a == INT_MIN && b == INT_MIN) { + return INT_MAX; + } + int64_t ab = ((int64_t)a) * ((int64_t)b); + int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30)); + // do not apply right shift to potential negative values + int ab_mantissa = (int)((ab + rounding) / (1ll << 31)); + return ab_mantissa; +} + +int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) { + if (a == SHRT_MIN && b == SHRT_MIN) { + return SHRT_MAX; + } + int32_t ab = ((int32_t)a) * ((int32_t)b); + int16_t rounding = ab >= 0 ? (1ll << 14) : (1ll - (1ll << 14)); + return (int16_t)((ab + rounding) / (1ll << 15)); +} + +// division by a 2^exponent with rounding +// or arithmetic right shift with rounding +int RoundingDivideByPOT(int x, int exponent) { + if (exponent > C31NUM) { + exponent = C31NUM; + } + const int mask = (1ll << exponent) - 1; + const int remainder = x & mask; + const int threshold = (mask >> 1) + (x < 0 ? 1 : 0); + return (x >> exponent) + (remainder > threshold ? 1 : 0); +} + +int UpwardRounding(int x, int exponent) { + const int32_t rounding_offset = (exponent > 0) ? (1 << (exponent - 1)) : 0; + if (x > INT32_MAX - rounding_offset) { + return 1 << (31 - exponent); + } + return (x + rounding_offset) >> exponent; +} + +int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); +} + +int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift, + int32_t right_shift) { + return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); +} + +int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) { + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift); +} + +int FractionsBits(int integer_bits) { return 8 * (int)(sizeof(int32_t)) - 1 - integer_bits; } + +int FixedPoint_One(int integer_bits, int fractions_bits) { + return (integer_bits == 0 ? INT32_MAX : ((1) << (uint32_t)(integer_bits == 0 ? 0 : fractions_bits))); +} + +int RoundingHalfSum(int32_t a, int32_t b) { + int64_t sum = (int64_t)a + (int64_t)b; + return (int32_t)((sum + (sum > 0 ? 1 : -1)) / 2); +} + +int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; } + +int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; } + +int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; } + +int32_t BitNot(int32_t a) { return ~(uint32_t)a; } + +int BitsSelect(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); } + +int ConstantPOT(int fractional_bits, int exponent) { return (1 << (uint32_t)(fractional_bits + exponent)); } + +int32_t MaskIfNonZero(int32_t a) { return a ? 
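/* any nonzero value yields an all-ones mask */ 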
BitNot(0) : 0; } + +int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); } + +int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); } + +uint32_t CountLeadingZeroBits(uint32_t x) { + if (x == 0) { + return 8 * sizeof(uint32_t) - 1; + } +#if defined(__GNUC__) + return __builtin_clz(x); +#else + const uint32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1); + uint32_t leading_zeros = 0; + while (x < leading_positive) { + x <<= 1; + leading_zeros++; + } + return leading_zeros; +#endif +} + +uint32_t CountLeadingSignBits(int32_t x) { + if (x == 0) { + return 8 * sizeof(int32_t) - 1; + } +#if defined(__GNUC__) && !defined(__clang__) + return __builtin_clrsb(x); +#else + return x >= 0 ? CountLeadingZeroBits((uint32_t)x) - 1 : x != INT32_MIN ? CountLeadingZeroBits(2 * (uint32_t)(-x)) : 0; +#endif +} + +int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) { + if (exponent > 0) { + const int min = INT32_MIN; + const int max = INT32_MAX; + const int scalar_int_bits = 8 * (int)(sizeof(int32_t)); + const int threshold = ((1 << (uint32_t)(scalar_int_bits - 1 - exponent)) - 1); + const int positive_mask = x > threshold ? BitNot(0) : 0; + const int negative_mask = x < -threshold ? BitNot(0) : 0; + int result = x * ((int32_t)(1) << (uint32_t)exponent); + result = BitsSelect(positive_mask, max, result); + result = BitsSelect(negative_mask, min, result); + return result; + } else if (exponent < 0) { + return RoundingDivideByPOT(x, -exponent); + } else { + return x; + } +} + +int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) { + int exponent = integer_bits_src - integer_bits_dst; + return SaturatingRoundingMultiplyByPOT(x, exponent); +} + +int32_t reciprocal_on_interval_between_0_1(int32_t a) { + int one = FixedPoint_One(0, FractionsBits(0)); + int half_sum = RoundingHalfSum(a, one); + const int constant_48_over_17 = 1515870810; + const int constant_neg_32_over_17 = -1010580540; + int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_sum, constant_neg_32_over_17); + for (int i = 0; i < 3; i++) { + int half_sum_times_x = SaturatingRoundingDoublingHighMul(half_sum, x); + int one_minus_half_sum_times_x = FixedPoint_One(2, FractionsBits(2)) - half_sum_times_x; + x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_sum_times_x), 2 + 2, 2); + } + return Rescale(x, 2 - 1, 0); +} + +int32_t ComputerReciprocal(int32_t x, uint32_t x_digits, int32_t *recip_shift) { + uint32_t leading_zeros_plus_one = CountLeadingZeroBits((uint32_t)x); + *recip_shift = x_digits - leading_zeros_plus_one; + const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zeros_plus_one) - ((uint32_t)(1) << 31)); + const int32_t shifted_scaled = reciprocal_on_interval_between_0_1(shifted_minus_one); + return shifted_scaled; +} + +int exp_on_interval_values(int a) { + const int constant_neg_1_over_8 = 1895147668; + const int constant_1_over_3 = 715827883; + int fractional_bits = FractionsBits(0); + int x = a + ConstantPOT(fractional_bits, -3); + int x2 = SaturatingRoundingDoublingHighMul(x, x); + int x3 = SaturatingRoundingDoublingHighMul(x2, x); + int x4 = SaturatingRoundingDoublingHighMul(x2, x2); + int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2); + int x4_over_24_plus_x3_over_6_plus_x2_over_2 = + SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1); + return constant_neg_1_over_8 + + SaturatingRoundingDoublingHighMul(constant_neg_1_over_8, (x + 
x4_over_24_plus_x3_over_6_plus_x2_over_2)); +} + +void exp_barrel_shifter(int exponent, int multiplier, int integer_bits, int fractional_bits, int remainder, + int32_t *result) { + if (integer_bits > exponent) { + int total_shift = integer_bits > exponent ? fractional_bits + exponent : 0; + *result = BitsSelect(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)total_shift))), + SaturatingRoundingDoublingHighMul(*result, multiplier), *result); + } +} + +int exp_on_negative_values(int a, const int integer_bits) { + int fractional_bits = FractionsBits(integer_bits); + const int one_quarter = ConstantPOT(fractional_bits, -2); + int a_mod_quarter_minus_one_quarter = ((unsigned)(a) & (one_quarter - 1)) - one_quarter; + int result = exp_on_interval_values(Rescale(a_mod_quarter_minus_one_quarter, integer_bits, 0)); + int remainder = a_mod_quarter_minus_one_quarter - a; + + exp_barrel_shifter(-2, 1672461947, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(-1, 1302514674, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+0, 790015084, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+1, 290630308, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+2, 39332535, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+3, 720401, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+4, 242, integer_bits, fractional_bits, remainder, &result); + + int clamp_bits = integer_bits > 5 ? 36 - integer_bits : 0; + if (integer_bits > 5) { + const int clamp = -(1 << (uint32_t)clamp_bits); + result = BitsSelect(MaskIfLessThan(a, clamp), 0, result); + } + result = BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result); + return result; +} + +void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift) { + if (input <= 1) { + *multiplier = INT_MAX; + *shift = 0; + return; + } + *shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*shift; + } + uint32_t max_left_shift_bits = CountLeadingSignBits(input); + if (max_left_shift_bits < 2) { + return; + } + uint32_t left_shift_bit_pairs = max_left_shift_bits / 2 - 1; + *shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + int32_t fixedpoint_f3_input = input >> 1; // sign: 1 bit, integer: 3 bit, fractional: 28 bit + int32_t fp_f3_half_input = SaturatingRoundingMultiplyByPOT(fixedpoint_f3_input, -1); + int32_t fp_f3_half_three = (1 << 28) + (1 << 27); + int32_t tmp = (1 << 28); // one + for (int i = 0; i < 5; i++) { + int32_t tmp3 = Rescale(SaturatingRoundingDoublingHighMul(tmp, SaturatingRoundingDoublingHighMul(tmp, tmp)), 9, 3); + tmp = Rescale(SaturatingRoundingDoublingHighMul(fp_f3_half_three, tmp) - + SaturatingRoundingDoublingHighMul(fp_f3_half_input, tmp3), + 6, 3); + } + const int32_t fp_f0_half_sqrt_2 = 1518500250; // sqrt(2) / 2 + tmp = SaturatingRoundingDoublingHighMul(tmp, fp_f0_half_sqrt_2); + *multiplier = tmp; + if (*shift < 0) { + *multiplier <<= -*shift; + *shift = 0; + } + *shift *= reverse_shift; +} + +#ifdef ENABLE_NEON +int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { + const int32x4_t shift_vec = vdupq_n_s32(-exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) { return vqrdmulhq_s32(a, b); } +#endif diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h new file mode 100644 index 0000000000000000000000000000000000000000..503a5e1d672a7e226d7cf1c71dbac8cd9724d863 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_QUANTIZATION_FIXED_POINT_H_ +#define NNACL_QUANTIZATION_FIXED_POINT_H_ + +#include <limits.h> +#include <stdint.h> +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +// returns the high-32 bits of a * b with rounding +// assume that a and b are divided by 2^31, so they fall into [-1, 1] +// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31 +// actually we compute 2 * a * b / 2^32 +// and take 32 bits of mantissa for rounding +int SaturatingRoundingDoublingHighMul(int a, int b); + +int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b); + +// division by a 2^exponent with rounding +// or arithmetic right shift with rounding +int RoundingDivideByPOT(int x, int exponent); + +int UpwardRounding(int x, int exponent); + +int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); + +int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift, + int32_t right_shift); + +int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift); + +int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent); + +int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst); + +uint32_t CountLeadingSignBits(int32_t x); + +int32_t ComputerReciprocal(int32_t x, uint32_t x_digits, int32_t *recip_shift); + +int exp_on_negative_values(int a, const int integer_bits); + +void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift); + +#ifdef __cplusplus +} +#endif + +#ifdef ENABLE_NEON +int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent); + +int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b); +#endif + +#endif // NNACL_QUANTIZATION_FIXED_POINT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..22d0618730fa1f6baed85131c1797a2024d8ff87 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/gatherNd_int8.h" +#include <math.h> +#include "nnacl_c/errorcode.h" + +int GatherNdInt8(int8_t *input, int8_t *output, const int32_t *in_offset, int area, int count, GatherQuantArg param) { + double alpha = param.alpha_; + int z1 = param.zp_in_; + int z2 = param.zp_out_; + for (int i = 0; i < count; ++i) { + for (int j = 0; j < area; ++j) { + int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2; + tmp = tmp > 127 ? 127 : tmp; + tmp = tmp < -128 ? -128 : tmp; + output[area * i + j] = (int8_t)tmp; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..91c74ab6649210ee3ad8c64b86c35033bb007c8b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_GATHERND_INT8_H_ +#define NNACL_INT8_GATHERND_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int GatherNdInt8(int8_t *in_data, int8_t *out_data, const int32_t *in_offset, int area, int count, + GatherQuantArg param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_GATHERND_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..148eb0c38f8a181cbe31368fd7bed8e5641129ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +#include "nnacl_c/int8/gather_int8.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/errorcode.h" + +int GatherInt8Int32Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int32_t *indices, int indices_element_size, GatherQuantArg para) { + double alpha = para.alpha_; + int z1 = para.zp_in_; + int z2 = para.zp_out_; + int i, m, j; + for (m = 0; m < outer_size; ++m) { + const int8_t *inputm = in_data + inner_size * m * limit; + int8_t *outputm = out_data + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] >= limit) { + return NNACL_ERR; + } + for (j = 0; j < inner_size; ++j) { + int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2; + tmp = tmp > 127 ? 127 : tmp; + tmp = tmp < -128 ? -128 : tmp; + outputm[i * inner_size + j] = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int GatherInt8Int64Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int64_t *indices, int indices_element_size, GatherQuantArg para) { + double alpha = para.alpha_; + int z1 = para.zp_in_; + int z2 = para.zp_out_; + int i, m, j; + for (m = 0; m < outer_size; ++m) { + const int8_t *inputm = in_data + inner_size * m * limit; + int8_t *outputm = out_data + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] >= limit) { + return NNACL_ERR; + } + for (j = 0; j < inner_size; ++j) { + int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2; + tmp = tmp > 127 ? 127 : tmp; + tmp = tmp < -128 ? -128 : tmp; + outputm[i * inner_size + j] = (int8_t)tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..86d3664be535dceb1a86d1e92202c579f8e37c2c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_GATHER_INT8_H_ +#define NNACL_INT8_GATHER_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int GatherInt8Int32Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int32_t *indices, int indices_element_size, GatherQuantArg para); + +int GatherInt8Int64Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int64_t *indices, int indices_element_size, GatherQuantArg para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_GATHER_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..01393dac1a0cef4310cded81db9cad4b854c75f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/hswish_int8.h" + +int16_t SaturatingLeftShift(int16_t value, int shift_num) { + int32_t result = (int32_t)value * (1 << shift_num); + return MSMAX(MSMIN(result, SHRT_MAX), SHRT_MIN); +} + +int HSwishInt8(const int8_t *src, int length, int8_t *dst, const HswishQuantArg *arg) { + for (int i = 0; i < length; i++) { + const int16_t input_value = src[i] - arg->input_zp; + const int16_t input_value_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + SaturatingRoundingDoublingHighMulInt16(input_value_scale, arg->output_multiplier_fixedpoint_int16); + int16_t relu6_value = input_value_scale; + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, arg->relu6_multiplier_exponent - 1); + } + relu6_value = SaturatingRoundingDoublingHighMulInt16(relu6_value, arg->relu6_multiplier_fixedpoint_int16); + + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, 1); + } + if (arg->relu6_multiplier_exponent < 0) { + relu6_value = RoundingDivideByPOT(relu6_value, -arg->relu6_multiplier_exponent); + } + relu6_value = (size_t)(relu6_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + SaturatingRoundingDoublingHighMulInt16(relu6_value, input_value_on_preshift_output_scale); + + int16_t output = RoundingDivideByPOT(preshift_output_value, -arg->output_multiplier_exponent); + output += arg->output_zp; + output = MSMIN(output, 127); + output = MSMAX(output, -128); + dst[i] = (int8_t)output; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..688f62ec4d6c245476b7051096c28ac020b2311a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei 
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_HSWISH_INT8_H_ +#define NNACL_INT8_HSWISH_INT8_H_ + +#include <limits.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/fixed_point.h" + +typedef struct HswishQuantArg { + double input_scale; + int32_t input_zp; + double output_scale; + int32_t output_zp; + int16_t relu6_multiplier_fixedpoint_int16; + int32_t relu6_multiplier_exponent; + int16_t output_multiplier_fixedpoint_int16; + int32_t output_multiplier_exponent; +} HswishQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif +int HSwishInt8(const int8_t *src, int length, int8_t *dst, const HswishQuantArg *arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_HSWISH_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..233ff4a159bc46e4fac17f17e18861c480ab2e7a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/int8/l2_norm_int8.h" +#include +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/errorcode.h" + +int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, + const L2NormQuantArg *quant_param, const int begin, const int end) { + const int inner_size = param->shape_[param->shape_num_ - 1]; + + for (int i = begin; i < end; ++i) { + int32_t square_sum = 0; + for (int j = 0; j < inner_size; ++j) { + int32_t in = input_data[i * inner_size + j] - quant_param->in_.zp_; + square_sum += in * in; + } + int32_t multiplier; + int32_t shift; + GetSqrtQuantMultiplierExp(square_sum, -1, &multiplier, &shift); + for (int k = 0; k < inner_size; ++k) { + int32_t in = input_data[i * inner_size + k] - quant_param->in_.zp_; + int32_t out = RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(in * (1 << 7), multiplier), -shift); + output_data[i * inner_size + k] = MSMIN(127, MSMAX(-128, out)); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..4cfec79797414208aa17f8726d94ba551e266a2b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_L2_NORM_INT8_H_ +#define NNACL_INT8_L2_NORM_INT8_H_ + +#include "nnacl_c/l2_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, + const L2NormQuantArg *quant_param, const int begin, const int end); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_L2_NORM_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..81d829c5507f12d9ec8bdae1365805676580e43a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/layer_norm_int8.h" + +void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data, + const LayerNormQuantArg *quant, int num, const float mean, const float deno) { + for (int i = 0; i < num; i++) { + float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_; + float fp32_dst = (fp32_src - mean) * deno; + fp32_dst = fp32_dst * gamma_data[i] + beta_data[i]; + int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_); + dst[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128); + } +} + +/* + * origin : (x-mean) / sqrt(variance + epsilon) * gamma + beta + * quant : (x-mean) / sqrt(sum(x * x) - mean * mean) * gamma + beta + * + * */ +int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, + const LayerNormComputeParam *param, const LayerNormQuantArg *quant, int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = NNACL_MIN((task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const int8_t *src_norm = src_data + i * param->norm_inner_size_; + int8_t *dst_norm = dst_data + i * param->norm_inner_size_; + float mean = 0.0f; + float square_mean = 0.0f; + for (int j = 0; j < param->norm_inner_size_; j++) { + float float_src = (src_norm[j] - quant->in_zp_) * quant->in_scale_; + mean += float_src; + square_mean += float_src * float_src; + } + mean /= (float)param->norm_inner_size_; + square_mean /= (float)param->norm_inner_size_; + const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); + + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const int8_t *src_param = src_norm + x * param->params_inner_size_; + int8_t *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBetaInt8(dst_param, src_param, gamma_data, beta_data, quant, param->norm_inner_size_, mean, + deno); + } + } else { + int x = i / param->params_outer_size_; + const float *gamma = gamma_data + x * param->norm_inner_size_; + const float *beta = beta_data + x * param->norm_inner_size_; + LayerNormGammaAndBetaInt8(dst_norm, src_norm, gamma, beta, quant, param->norm_inner_size_, mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e33b34b2ba853d89bd7935f67011977ef954e8bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_LAYER_NORM_H_ +#define NNACL_INT8_LAYER_NORM_H_ + +#include "nnacl_c/errorcode.h" +#include "nnacl_c/layer_norm_parameter.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, + const LayerNormComputeParam *param, const LayerNormQuantArg *quant, int task_id, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_LAYER_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..8b38c5f13eb2ea85972f26be404a66af68f155d4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/leaky_relu_int8.h" +#include "nnacl_c/errorcode.h" + +int DoLeakReluInt8(const int8_t *inputs, int8_t *output_ptr, const LeakyReluQuantArg *quant_prelu_parm, int task_id) { + if (quant_prelu_parm == NULL) { + return NNACL_NULL_PTR; + } + float output_scale = quant_prelu_parm->out_args_.scale_; + int output_zp = quant_prelu_parm->out_args_.zp_; + const float output_inverse_scale = 1.f / output_scale; + + float scale = quant_prelu_parm->in_args_.scale_ * output_inverse_scale; + float bias = -quant_prelu_parm->in_args_.zp_ * scale; + for (int j = task_id; j < quant_prelu_parm->element_num; j += quant_prelu_parm->thread_num_) { + if (inputs[j] <= 0) { + int32_t output_tmp = round(inputs[j] * quant_prelu_parm->slope_ * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } else { + int32_t output_tmp = round(inputs[j] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..cde802958d1eec5eb521e54dab92105843dec3e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_PRELU_INT8_H_ +#define NNACL_INT8_PRELU_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoLeakReluInt8(const int8_t *inputs, int8_t *output_ptr, const LeakyReluQuantArg *quant_prelu_parm, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_PRELU_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..cfe16eb3ad6e99536516a3b75af7c14068ba9373 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c @@ -0,0 +1,839 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col16 = UP_ROUND(col, C16NUM); + for (int r = 0; r < row; r++) { + int rd2 = r / C2NUM; + int rm2 = r % C2NUM; + for (int c = 0; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + int dst_index = rd2 * col16 * C2NUM + cd16 * C2NUM * C16NUM + rm2 * C16NUM + cm16; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void RowMajor2Row4x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_div = row / C4NUM * C4NUM; + int col_4 = UP_ROUND(col, C4NUM); + int col_div = col / C4NUM * C4NUM; + + const int8_t *src_r4 = src; + int8_t *packed_r4 = dst; + const int8_t *src_c4 = NULL; + int8_t *packed_c4 = NULL; + for (int r = 0; r < row_div; r += C4NUM) { + src_c4 = src_r4; + packed_c4 = packed_r4; + + for (int c = 0; c < col_div; c += C4NUM) { + for (int i = 0; i < C4NUM; i++) { + packed_c4[i * C4NUM + 0] = src_c4[i * col + 0]; + packed_c4[i * C4NUM + 1] = src_c4[i * col + 1]; + packed_c4[i * C4NUM + 2] = src_c4[i * col + 2]; + packed_c4[i * C4NUM + 3] = src_c4[i * col + 3]; + } + src_c4 += C4NUM; + packed_c4 += C16NUM; + } + + if (col != col_div) { + memset(packed_c4, 0, C16NUM * sizeof(int8_t)); + for (int i = 0; i < C4NUM; ++i) { + for (int c = 0; c < col - col_div; ++c) { + packed_c4[i * C4NUM + c] = src_c4[i * col + c]; + } + } + } + // advance to the next block of C4NUM rows even when col is already 4-aligned + src_r4 += C4NUM * col; + packed_r4 += C4NUM * col_4; + } + + if (row == row_div) { + return; + } + memset(packed_r4, 0, C4NUM * col_4); + src_c4 = src_r4; + packed_c4 = packed_r4; + for (int c = 0; c < col_div; c += C4NUM) { + for (int i = 0; i < row - row_div; 
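/* pack the leftover rows below the last 4-row block */ 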
++i) { + packed_c4[i * C4NUM + 0] = src_c4[i * col + 0]; + packed_c4[i * C4NUM + 1] = src_c4[i * col + 1]; + packed_c4[i * C4NUM + 2] = src_c4[i * col + 2]; + packed_c4[i * C4NUM + 3] = src_c4[i * col + 3]; + } + src_c4 += C4NUM; + packed_c4 += C16NUM; + } + if (col == col_div) { + return; + } + for (int i = 0; i < row - row_div; ++i) { + for (int c = 0; c < col - col_div; ++c) { + packed_c4[i * C4NUM + c] = src_c4[i * col + c]; + } + } +} + +void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int row16 = UP_ROUND(row, C16NUM); + int stride = C16NUM * C2NUM; + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C2NUM * (row16 / C16NUM) + r / C16NUM; + int dst_idx = stride * stride_idx + c % C2NUM * C16NUM + r % C16NUM; + int src_idx = r * col + c; + dst_ptr[dst_idx] = src_ptr[src_idx]; + } + } +} + +void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd8 = r / C8NUM; + int rm8 = r % C8NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void MatrixPack4x16UnitInt8(const int8_t *src, int8_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + const int8_t *src_r = src + r * stride; + int8_t *dst_r = dst + r * C16NUM; + memcpy(dst_r, src_r, col * sizeof(int8_t)); + } + return; +} + +void MatrixEmptyInt8(int8_t *dst, int row, int col) { + for (int r = 0; r < row; r++) { + int8_t *dst_r = dst + r * C16NUM; + memset(dst_r, 0, col * sizeof(int8_t)); + } + return; +} + +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd16 = r / C16NUM; + int rm16 = r % C16NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + /* Row-major to row16x4-major (block row-major) */ + int col16 = UP_ROUND(col, C16NUM); + int row_4div = row / C4NUM * C4NUM; + int row_4res = row - row_4div; + int col_16div = col / C16NUM * C16NUM; + int col_16res = col - col_16div; + int8_t *src_r = (int8_t *)src_ptr; + int8_t *dst_r = (int8_t *)dst_ptr; + + for (int ri = 0; ri < row_4div; ri += C4NUM) { + for (int ci = 0; ci < col_16div; ci += C16NUM) { + size_t col_offset = (size_t)col; + int8_t *src_c = src_r + ci; + int8_t *dst_c = dst_r + ci * C4NUM; +#ifdef ENABLE_ARM64 + asm volatile( + "mov x10, %[src_c] \n" + "mov x11, %[dst_c] \n" + + "ld1 {v0.16b}, [x10], %[col_offset]\n" + "ld1 {v1.16b}, [x10], %[col_offset]\n" + "ld1 {v2.16b}, [x10], %[col_offset]\n" + "ld1 {v3.16b}, [x10], %[col_offset]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "st1 {v2.16b}, [x11], #16\n" + "st1 {v3.16b}, [x11], #16\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [col_offset] "r"(col_offset) + : "x10", "x11", "v0", "v1", "v2", "v3"); +#elif ENABLE_ARM32 + asm volatile( + "mov r0, %[src_c] \n" + "mov r1, %[dst_c] \n" + "mov r2, %[col_offset] \n" + "mov r3, #16 \n" + + "vld1.8 {q0}, [r0], r2 \n" + "vld1.8 {q1}, [r0], 
r2 \n" + "vld1.8 {q2}, [r0], r2 \n" + "vld1.8 {q3}, [r0], r2 \n" + + "vst1.32 {d0, d1}, [r1], r3 \n" + "vst1.32 {d2, d3}, [r1], r3 \n" + "vst1.32 {d4, d5}, [r1], r3 \n" + "vst1.32 {d6, d7}, [r1], r3 \n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [col_offset] "r"(col_offset) + : "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3"); +#else + MatrixPack4x16UnitInt8(src_c, dst_c, C4NUM, C16NUM, col_offset); +#endif + } + + if (col != col_16div) { + MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col); + MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res); + } + src_r += C4NUM * col; + dst_r += C4NUM * col16; + } + + if (row != row_4div) { + memset(dst_r, 0, C4NUM * col16); + + for (int ci = 0; ci < col_16div; ci += C16NUM) { + MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col); + } + + if (col != col_16div) { + MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, row_4res, col_16res, col); + } + } + return; +} + +void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias) { + /* row4x16-major * row16x4-major => row4x4-major */ + for (int r = 0; r < row_4; r++) { + for (int c = 0; c < col_4; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int64_t ci = c4div * row_4 * C4NUM + r * C4NUM + c4mod; + int32_t value = 0; + for (int d = 0; d < deep_16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + int64_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + int64_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + value -= input_sum[r]; + value += bias[c]; + ((int32_t *)dst)[ci] = value; + } + } + return; +} + +void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, bool peroc) { + /* support per-layer && weight per-channel */ + /* row4x16-major * row16x2-major => (int8)row-major*/ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r4div = r / C4NUM, r4mod = r % C4NUM; + size_t c2div = c / C2NUM, c2mod = c % C2NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_16; d++) { + size_t d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c2div * deep_16 * C2NUM + d16div * C2NUM * C16NUM + c2mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = + peroc ? input_sum[c2div * UP_ROUND(row, C4NUM) * C2NUM + r * C2NUM + c2mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = peroc ? 
multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +#ifndef ENABLE_ARM +void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int32_t *a_sums, + const int32_t *bias, int mini, int maxi, int out_zp, const int32_t *multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, + const int32_t *filter_zp) { + /* + * row4x16-major * row16x4-major => (int8)row-major + * support per-layer && weight per-channel + * a_sums is perT : input_row_sum * filter_zp + * perOc : input_row_sum + * */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int64_t ci = r * stride + c; + int32_t value = 0; + for (int d = 0; d < deep16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + int64_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + int64_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = filter_peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = filter_peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = filter_peroc ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + out_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} +#endif +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel) { + /* row8x4-major * row4x8-major => (int8)row-major */ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r8div = r / C8NUM, r8mod = r % C8NUM; + size_t c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_4; d++) { + size_t d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod; + size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = + per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? 
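+      /* Requantization parameter selection: with per-channel quantization each
+       * output channel c carries its own (multiplier, left shift, right shift)
+       * triple; with per-tensor quantization entry 0 is shared by all channels. */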
multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, const int32_t *filter_zp) { + /* row4x4-major * row4x16-major => (int8)row-major */ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r4div = r / C4NUM, r4mod = r % C4NUM; + size_t c16div = c / C16NUM, c16mod = c % C16NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_4; d++) { + size_t d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep_4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep_4 * C16NUM + d4div * C16NUM * C4NUM + c16mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = per_channel ? input_sum[r] * filter_zp[c] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *pack_ic, int32_t *input_sum_r, size_t src_stride, + size_t ic_4div, size_t ic_4res, int32_t filter_zp) { + asm volatile( + "dup v2.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v3.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], 
%[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "6: \n" + "mul v2.4s, v2.4s, v3.4s \n" + + "st1 {v2.4s}, [x14], #16 \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [input_sum_r] "r"(input_sum_r), [src_stride] "r"(src_stride), + [ic_4div] "r"(ic_4div), [ic_4res] "r"(ic_4res), [filter_zp] "r"(filter_zp) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); + return; +} +#endif +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp) { + size_t ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (size_t hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + PackInput4x4AndInputSumPert_arm64(src_ic, pack_ic, input_sum_r, src_stride, ic_4div, ic_4res, filter_zp); +#else + int32_t tmp_sum_value[4] = {0}; + for (size_t ici = 0; ici < ic_4div; ici += C4NUM) { + for (size_t i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (size_t ici = ic_4div; ici < input_channel; ici += 1) { + for (size_t i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (size_t ici = input_channel; ici < ic4; ici += 1) { + for (size_t i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (size_t i = 0; i < C4NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + (void)memset(pack_r, 0, C4NUM * ic4); + for (size_t hwi = hw_4div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (size_t ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (size_t ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + 
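+      /* At this point input_sum[hwi] holds filter_zp * sum over all input
+       * channels of the pixel just packed.  The int8 GEMM subtracts it once per
+       * output element: it is the cross term sum_ic(a_ic * w_zp) that appears
+       * when the weight zero-point is factored out of sum_ic(a_ic * (w_ic - w_zp)).
+       * Example: channel values {3, -1, 2} with filter_zp = 5 give 5 * 4 = 20. */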
} + for (size_t hwi = plane_size; hwi < hw4; hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput2Col4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *packed_ic, int32_t *input_sum, int row, + size_t row_stride, int32_t filter_zp) { + asm volatile( + "ld1 {v12.s}[0], [%[input_sum]]\n" + "mov w10, %w[row]\n" + "mov x11, %[src_ic]\n" + "mov x12, %[packed_ic]\n" + "sxtl v6.8h, v12.8b\n" + "sxtl v12.4s, v6.4h\n" + "cmp w10, wzr\n" + "beq 1f\n" + "2:\n" + "subs w10, w10, #4\n" + "ld1 {v0.s}[0], [x11], %[row_stride]\n" + "ld1 {v1.s}[0], [x11], %[row_stride]\n" + "ld1 {v0.s}[1], [x11], %[row_stride]\n" + "ld1 {v1.s}[1], [x11], %[row_stride]\n" + "zip1 v2.8b, v0.8b, v1.8b\n" + "zip2 v3.8b, v0.8b, v1.8b\n" + "zip1 v4.4h, v2.4h, v3.4h\n" + "zip2 v5.4h, v2.4h, v3.4h\n" + "st1 {v4.4h, v5.4h}, [x12], #16\n" + + "sxtl v6.8h, v0.8b\n" + "sxtl v7.4s, v6.4h\n" + "sxtl2 v8.4s, v6.8h\n" + "sxtl v9.8h, v1.8b\n" + "sxtl v10.4s, v9.4h\n" + "sxtl2 v11.4s, v9.8h\n" + "add v10.4s, v10.4s, v7.4s\n" + "add v10.4s, v10.4s, v8.4s\n" + "add v10.4s, v10.4s, v10.4s\n" + "add v10.4s, v10.4s, v11.4s\n" + "bgt 2b\n" + "1:\n" + + : + : [src_ic] "r"(src_ic), [packed_ic] "r"(packed_ic), [input_sum] "r"(input_sum), [row] "r"(row), + [row_stride] "r"(row_stride), [filter_zp] "r"(filter_zp) + : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); + + return; +} +#endif + +// For matmul input a transpose case +void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row, + int col, int row_stride, int32_t filter_zp) { + const int row_tile = C4NUM; + int row_align = UP_ROUND(row, row_tile); + int row_div = row / row_tile * row_tile; + const int row_res = row - row_div; + + const int col_tile = C4NUM; + int col_div = col / col_tile * col_tile; + const int col_res = col - col_div; + + const int8_t *src_ic = NULL; + int8_t *packed_ic = NULL; + int32_t *tmp_sum = NULL; + for (int c = 0; c < col_div; c += C4NUM) { + int r = 0; + src_ic = src_input + c; + packed_ic = packed_input + c * row_align; + tmp_sum = input_sum + c; +#ifdef ENABLE_ARM64 + PackInput2Col4x4AndInputSumPert_arm64(src_ic, packed_ic, tmp_sum, row_div, row_stride, filter_zp); + packed_ic += C4NUM * row_div; + src_ic += row_div * row_stride; +#else + for (; r < row_div; r += C4NUM) { + for (int i = 0; i < row_tile; i++) { + packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0]; + packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1]; + packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2]; + packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3]; + + tmp_sum[0] += src_ic[i * row_stride + 0]; + tmp_sum[1] += src_ic[i * row_stride + 1]; + tmp_sum[2] += src_ic[i * row_stride + 2]; + tmp_sum[3] += src_ic[i * row_stride + 3]; + } + packed_ic += C16NUM; + src_ic += row_tile * row_stride; + } +#endif + + for (r = 0; r < row_res; ++r) { + for (int i = 0; i < C4NUM; ++i) { + packed_ic[i * row_tile + r] = src_ic[r * row_stride + i]; + tmp_sum[i] += src_ic[r * row_stride + i]; + } + } + } + if (col_res == 0) { + for (int i = 0; i < col; ++i) { + input_sum[i] *= filter_zp; + } + return; + } + src_ic = src_input + col_div; + packed_ic = packed_input + row_align * col_div; + tmp_sum = input_sum + col_div; + for (int r = 0; r < row_div; r += row_tile) { + for (int i = 0; i < col_res; ++i) { + packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i]; + packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i]; + packed_ic[i * 
row_tile + 2] = src_ic[2 * row_stride + i]; + packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i]; + + tmp_sum[i] += src_ic[0 * row_stride + i]; + tmp_sum[i] += src_ic[1 * row_stride + i]; + tmp_sum[i] += src_ic[2 * row_stride + i]; + tmp_sum[i] += src_ic[3 * row_stride + i]; + } + src_ic += row_tile * row_stride; + packed_ic += row_tile * col_tile; + } + + for (int r = 0; r < row_res; ++r) { + for (int c = 0; c < col_res; ++c) { + packed_ic[c * row_tile + r] = src_ic[r * row_stride + c]; + tmp_sum[c] += src_ic[r * row_stride + c]; + } + } + + for (int i = 0; i < col; ++i) { + input_sum[i] *= filter_zp; + } +} + +void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_16 = UP_ROUND(row, C16NUM); + int stride = sizeof(int8_t) * 16 * 4; + for (int r = 0; r < row_16; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / 4 * (row_16 / 16) + r / 16; + if (r >= row) { + dst[stride * stride_idx + c % 4 * 16 + r % 16] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % 4 * 16 + r % 16] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C4NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C4NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < cur_oc; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int32_t *dst, DataOrder order) { + for (int r = 0; r < row; ++r) { + int sum = 0; + for (int c = 0; c < col; ++c) { + if (order == RowMajor) { + sum += input[r * col + c]; + } else { + sum += input[c * row + r]; + } + } + sum *= weight_zp; + dst[r] = sum; + } +} + +// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums +void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int32_t *weight_zp_ptr, + const int32_t *bias, int32_t *dst, DataOrder order, bool filter_per_channel) { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + if (order == RowMajor) { + sum += weight[r * col + c]; + } else { + sum += weight[c * row + r]; + } + } + int weight_zp = filter_per_channel ? 
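+    /* Zero-point folding used here: sum_d (a_d - zp_in) * (w_d - zp_w)
+     *   = sum_d a_d * w_d - zp_w * sum_d a_d - zp_in * sum_d w_d + row * zp_in * zp_w.
+     * The GEMM kernels subtract the zp_w * sum(a) term through input_sum; the
+     * remaining two terms, plus the conventional bias, are folded into dst[c] below. */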
weight_zp_ptr[c] : weight_zp_ptr[0];
+    dst[c] = row * input_zp * weight_zp - input_zp * sum;
+    if (bias != NULL) {
+      dst[c] += bias[c];
+    }
+  }
+}
+
+void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp,
+                            const int32_t *weight_zp_ptr, const int32_t *bias, int32_t *dst, DataOrder order,
+                            bool filter_per_channel) {
+  for (int c = 0; c < cur_col; ++c) {
+    int sum = 0;
+    for (int r = 0; r < row; ++r) {
+      if (order == RowMajor) {
+        sum += weight[r * stride + c];
+      } else {
+        sum += weight[c * row + r];
+      }
+    }
+    int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0];
+    dst[c] = row * input_zp * weight_zp - input_zp * sum;
+    if (bias != NULL) {
+      dst[c] += bias[c];
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..3dec3b18c042377cf941315dbf6982c8e6af0a8f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_MATMUL_H_
+#define NNACL_INT8_MATMUL_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/matmul_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/* 4x16 16x4 -> 4x4 */
+/* sdot 4x4 4x16 -> 4x16 */
+/* matmul */
+void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16,
+                     const int32_t *input_sum, const int32_t *bias);
+void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col);
+void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col);
+void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc);
+void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row,
+                                     int col, int row_stride, int32_t filter_zp);
+void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int32_t *dst, DataOrder order);
+void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int32_t *weight_zp_ptr,
+                        const int32_t *bias, int32_t *dst, DataOrder order, bool filter_per_channel);
+void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp,
+                            const int32_t *weight_zp_ptr, const int32_t *bias, int32_t *dst, DataOrder order,
+                            bool filter_per_channel);
+void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int32_t *a_sums,
+                   const int32_t *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
+                   const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
+                   const int32_t *filter_zp);
+/* 8x4 4x8 -> 8x8 */
+/* optimize conv */
+void
RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel); + +/* 4x16 16x2 -> 4x2 */ +/* arm32 conv1x1 */ +void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, bool peroc); + +/* 4x4 4x16 -> 4x16 */ +/* optimize conv1x1 */ +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp); +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, const int32_t *filter_zp); + +#ifdef ENABLE_ARM64 +void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, + const int32_t *a_sums, const int32_t *bias, int act_min, int act_max, int out_zp, + int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, + int filter_peroc); + +void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, + const int32_t *input_sum, const int32_t *bias); +#endif +#ifdef ENABLE_ARM32 +void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, + const int32_t *input_sums, const int32_t *weight_bias, int act_min, int act_max, int out_zp, + int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int stride, int per_channel); +#endif +#ifdef __cplusplus +} +#endif + +#endif // LITE_SRC_BACKEND_ARM_NNACL_INT8_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..d7e5593452cb3d1b8c5b41556a91202d3f701c64 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c @@ -0,0 +1,238 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
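The matmul_int8.h API above composes into a full int8 GEMM: pack the left operand into 4-row by 16-deep blocks, pack the right operand into 16-deep by 4-column blocks, precompute the zero-point corrections, then run the blocked kernel. A minimal host-side sketch of that flow for per-tensor quantization; the driver function, its shapes, and its (elided) error handling are hypothetical, not part of the diff:

```c
#include <stdlib.h>
#include "nnacl_c/int8/matmul_int8.h"

/* Hypothetical driver: C(row x col) = A(row x deep) * B(deep x col), per-tensor quant. */
void Int8GemmSketch(const int8_t *A, const int8_t *B, const int32_t *bias, int row, int col, int deep,
                    int32_t input_zp, int32_t weight_zp) {
  int row4 = UP_ROUND(row, C4NUM), col4 = UP_ROUND(col, C4NUM), deep16 = UP_ROUND(deep, C16NUM);
  int8_t *a_pack = (int8_t *)calloc((size_t)row4 * deep16, 1);
  int8_t *b_pack = (int8_t *)calloc((size_t)col4 * deep16, 1);
  int32_t *a_sums = (int32_t *)calloc(row4, sizeof(int32_t));    /* weight_zp * row sums of A */
  int32_t *bias_sums = (int32_t *)calloc(col4, sizeof(int32_t)); /* folded bias, see CalcWeightBiasSums */
  int32_t *out = (int32_t *)calloc((size_t)row4 * col4, sizeof(int32_t));

  RowMajor2Row16x4MajorInt8(A, a_pack, row, deep);
  RowMajor2Col16x4MajorInt8(B, b_pack, deep, col);
  CalcInputSums(A, row, deep, weight_zp, a_sums, RowMajor);
  CalcWeightBiasSums(B, deep, col, input_zp, &weight_zp, bias, bias_sums, RowMajor, false);
  /* out holds int32 accumulators in the blocked row4x4 layout; requantization to
   * int8 (multiplier + shifts) is a separate step in the callers. */
  MatMulInt8_16x4(a_pack, b_pack, out, row4, col4, deep16, a_sums, bias_sums);

  free(a_pack); free(b_pack); free(a_sums); free(bias_sums); free(out);
}
```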
+ */ + +#include "nnacl_c/int8/mul_int8.h" + +#ifdef ENABLE_NEON +int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { + int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); + int32x4_t raw_sum = vqrdmulhq_s32(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(raw_sum, right_shift_out_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(raw_sum, fixup); + raw_sum = vrshlq_s32(fixed_up_x, right_shift_out_vec); + return vqmovn_s32(raw_sum); +} + +void MulInt8NEON(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg, int32_t *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(quant_arg->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)quant_arg->shift_left_); + int32x4_t right_shift_out_vec = vdupq_n_s32(-quant_arg->shift_right_); + int16x8_t out_zp_vec = vdupq_n_s16(quant_arg->out_quant_arg_.zp_); + int8x16_t out_min_vec = vdupq_n_s8(quant_arg->output_activation_min_); + int8x16_t out_max_vec = vdupq_n_s8(quant_arg->output_activation_max_); + int8x8_t out_min_vec_s8 = vdup_n_s8(quant_arg->output_activation_min_); + int8x8_t out_max_vec_s8 = vdup_n_s8(quant_arg->output_activation_max_); + + for (; (*index) <= real_dst_count - 16; (*index) += 16) { + int16x8_t zp1_vec = vdupq_n_s16(quant_arg->in_quant_args_[0].zp_); + int16x8_t zp2_vec = vdupq_n_s16(quant_arg->in_quant_args_[1].zp_); + int8x16_t input0_vec = vld1q_s8(input0_data + *index); + int8x16_t input1_vec = vld1q_s8(input1_data + *index); + int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); + int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); + int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); + int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); + input0_low = vaddq_s16(input0_low, zp1_vec); + input0_high = vaddq_s16(input0_high, zp1_vec); + input1_low = vaddq_s16(input1_low, zp2_vec); + input1_high = vaddq_s16(input1_high, zp2_vec); + + int16x4_t input0_low_low = vget_low_s16(input0_low); + int16x4_t input0_low_high = vget_high_s16(input0_low); + int16x4_t input0_high_low = vget_low_s16(input0_high); + int16x4_t input0_high_high = vget_high_s16(input0_high); + int16x4_t input1_low_low = vget_low_s16(input1_low); + int16x4_t input1_low_high = vget_high_s16(input1_low); + int16x4_t input1_high_low = vget_low_s16(input1_high); + int16x4_t input1_high_high = vget_high_s16(input1_high); + + int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, right_shift_out_vec, + output_multiplier_vec); + int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = 
vminq_s8(res_s8, out_max_vec);
+    res_s8 = vmaxq_s8(res_s8, out_min_vec);
+    vst1q_s8(output_data, res_s8);
+    output_data += 16;
+  }
+  for (; (*index) <= real_dst_count - 8; (*index) += 8) {
+    int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, quant_arg->in_quant_args_[0].zp_);
+    int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, quant_arg->in_quant_args_[1].zp_);
+
+    int16x4_t input0_low = vget_low_s16(input0_val);
+    int16x4_t input0_high = vget_high_s16(input0_val);
+    int16x4_t input1_low = vget_low_s16(input1_val);
+    int16x4_t input1_high = vget_high_s16(input1_val);
+
+    int16x4_t sum_low =
+      ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
+    int16x4_t sum_high =
+      ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
+
+    int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec);
+    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
+    res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8);
+    res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8);
+    vst1_s8(output_data, res_u8_n0);
+    output_data += 8;
+  }
+}
+#endif
+
+void FastMul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int depth,
+             int64_t real_dst_count, bool input1_broad, const MulQuantArg *quant_arg) {
+  // input0 needs broadcast
+  int32_t zp1 = quant_arg->in_quant_args_[0].zp_;
+  int32_t zp2 = quant_arg->in_quant_args_[1].zp_;
+  if (input1_broad) {
+    zp1 = quant_arg->in_quant_args_[1].zp_;
+    zp2 = quant_arg->in_quant_args_[0].zp_;
+  }
+#ifdef ENABLE_NEON
+  int32x4_t output_multiplier_vec = vdupq_n_s32(quant_arg->output_multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)quant_arg->shift_left_);
+  int32x4_t right_shift_out_vec = vdupq_n_s32(-quant_arg->shift_right_);
+  int16x8_t out_zp_vec = vdupq_n_s16(quant_arg->out_quant_arg_.zp_);
+  int8x16_t out_min_vec = vdupq_n_s8(quant_arg->output_activation_min_);
+  int8x16_t out_max_vec = vdupq_n_s8(quant_arg->output_activation_max_);
+  int8x8_t out_min_vec_s8 = vdup_n_s8(quant_arg->output_activation_min_);
+  int8x8_t out_max_vec_s8 = vdup_n_s8(quant_arg->output_activation_max_);
+  int16x8_t zp1_vec = vdupq_n_s16(zp1);
+  int16x8_t zp2_vec = vdupq_n_s16(zp2);
+#endif
+  for (int index = 0; index < real_dst_count; ++index) {
+    int j = 0;
+#ifdef ENABLE_NEON
+    for (; j <= depth - 16; j += 16) {
+      int8x16_t input0_vec = vld1q_s8(input0_data + j);
+      int8x16_t input1_vec = vld1q_s8(input1_data);
+      int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec));
+      int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec));
+      int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec));
+      int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec));
+      input0_low = vaddq_s16(input0_low, zp1_vec);
+      input0_high = vaddq_s16(input0_high, zp1_vec);
+      input1_low = vaddq_s16(input1_low, zp2_vec);
+      input1_high = vaddq_s16(input1_high, zp2_vec);
+
+      int16x4_t input0_low_low = vget_low_s16(input0_low);
+      int16x4_t input0_low_high = vget_high_s16(input0_low);
+      int16x4_t input0_high_low = vget_low_s16(input0_high);
+      int16x4_t input0_high_high = vget_high_s16(input0_high);
+      int16x4_t input1_low_low = vget_low_s16(input1_low);
+      int16x4_t input1_low_high = vget_high_s16(input1_low);
+      int16x4_t input1_high_low = vget_low_s16(input1_high);
+      int16x4_t input1_high_high = vget_high_s16(input1_high);
+
+      int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec,
+                                                 right_shift_out_vec, output_multiplier_vec);
+
int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = vminq_s8(res_s8, out_max_vec); + res_s8 = vmaxq_s8(res_s8, out_min_vec); + vst1q_s8(output_data, res_s8); + input1_data += 16; + output_data += 16; + } + for (; j <= depth - 8; j += 8) { + int8x8_t input0_vec = vld1_s8(input0_data + j); + int8x8_t input1_vec = vld1_s8(input1_data); + int16x8_t input0_val = vmovl_s8(input0_vec); + int16x8_t input1_val = vmovl_s8(input1_vec); + input0_val = vaddq_s16(input0_val, zp1_vec); + input1_val = vaddq_s16(input1_val, zp2_vec); + + int16x4_t input0_low = vget_low_s16(input0_val); + int16x4_t input0_high = vget_high_s16(input0_val); + int16x4_t input1_low = vget_low_s16(input1_val); + int16x4_t input1_high = vget_high_s16(input1_val); + + int16x4_t sum_low = + ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high = + ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); + res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); + vst1_s8(output_data, res_u8_n0); + input1_data += 8; + output_data += 8; + } +#endif + for (; j < depth; ++j) { + const int32_t input0_val = zp1 + input0_data[j]; + const int32_t input1_val = zp2 + input1_data[0]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << (size_t)quant_arg->shift_left_), + quant_arg->output_multiplier_), + quant_arg->shift_right_); + + mul_result += quant_arg->out_quant_arg_.zp_; + mul_result = mul_result < quant_arg->output_activation_max_ ? mul_result : quant_arg->output_activation_max_; + mul_result = mul_result > quant_arg->output_activation_min_ ? mul_result : quant_arg->output_activation_min_; + output_data[0] = (int8_t)mul_result; + input1_data++; + output_data++; + } + } + return; +} + +void Mul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg) { + int index = 0; +#ifdef ENABLE_NEON + MulInt8NEON(input0_data, input1_data, output_data, real_dst_count, quant_arg, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = quant_arg->in_quant_args_[0].zp_ + input0_data[index]; + const int32_t input1_val = quant_arg->in_quant_args_[1].zp_ + input1_data[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << (size_t)quant_arg->shift_left_), + quant_arg->output_multiplier_), + quant_arg->shift_right_); + + mul_result += quant_arg->out_quant_arg_.zp_; + mul_result = mul_result < quant_arg->output_activation_max_ ? 
mul_result : quant_arg->output_activation_max_;
+    mul_result = mul_result > quant_arg->output_activation_min_ ? mul_result : quant_arg->output_activation_min_;
+    output_data[index] = (int8_t)mul_result;
+  }
+  return;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef88bba11bad3075e4cda8135127dbcf48ee1cf3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_MUL_INT8_H_
+#define NNACL_INT8_MUL_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/mul_parameter.h"
+#include "nnacl_c/int8/common_func_int8.h"
+#include "nnacl_c/int8/fixed_point.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Mul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
+         const MulQuantArg *quant_arg);
+void FastMul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int depth,
+             int64_t real_dst_count, bool input1_broad, const MulQuantArg *quant_arg);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_MUL_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..25f0eceac44e7b5243ab4b9880d11fcc4b069b86
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c
@@ -0,0 +1,452 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
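Both the NEON and scalar paths of Mul/FastMul above requantize with the same gemmlowp-style sequence: pre-multiply by `1 << shift_left_`, apply the Q31 multiplier with SaturatingRoundingDoublingHighMul, then RoundingDivideByPOT by `shift_right_`. In real arithmetic the combined factor is `multiplier / 2^31 * 2^(shift_left - shift_right)`. A double-precision reference of that identity (hypothetical helper, handy when checking a MulQuantArg setup; it ignores the saturation and tie-breaking details of the fixed-point path):

```c
#include <math.h>
#include <stdint.h>

/* Reference for the fixed-point requant used in Mul()/FastMul():
 * result ~= prod * multiplier / 2^31 * 2^(left_shift - right_shift). */
static int32_t RequantRef(int32_t prod, int32_t multiplier, int left_shift, int right_shift) {
  double real = (double)multiplier / 2147483648.0; /* Q31 -> real */
  real *= (double)(1 << left_shift) / (double)(1 << right_shift);
  return (int32_t)lround((double)prod * real);
}
```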
+ */ + +#include "nnacl_c/int8/pack_int8.h" + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param) { + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic8_round = UP_ROUND(in_channel, C8NUM); + int ic8 = in_channel / C8NUM * C8NUM; + int in_plane = in_h * in_w; + + for (int b = 0; b < in_batch; b++) { + int src_batch_offset = b * in_channel * in_plane; + int dst_batch_offset = b * ic8_round * in_plane; + for (int k = 0; k < in_plane; k++) { + int src_plane_offset = src_batch_offset + k * in_channel; + int dst_plane_offset = dst_batch_offset + k * C8NUM; + for (int i = 0; i < ic8; i += 8) { + int src_c_offset = src_plane_offset + i; + int dst_c_offset = dst_plane_offset + i * in_plane; +#ifdef ENABLE_ARM + vst1q_s16(packed_input + dst_c_offset, vmovl_s8(vld1_s8(input_data + src_c_offset))); +#else + for (int j = 0; j < C8NUM; ++j) { + (packed_input + dst_c_offset)[j] = (int16_t)(input_data + src_c_offset)[j]; + } +#endif + } // ic8_minus loop + int res_c = in_channel - ic8; + int tmp_ic_offset = ic8 * in_plane; + for (int l = 0; l < res_c; ++l) { + int src_c_offset = src_plane_offset + ic8 + l; + int dst_c_offset = dst_plane_offset + tmp_ic_offset + l; + (packed_input + dst_c_offset)[0] = (int16_t)(input_data + src_c_offset)[0]; + } // res ic loop + int res2 = ic8_round - in_channel; + for (int l = 0; l < res2; ++l) { + int dst_c_offset = dst_plane_offset + tmp_ic_offset + res_c + l; + (packed_input + dst_c_offset)[0] = 0; + } // res ic loop + } // kh * kw loop + } +} + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, + const ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = input_channel / C8NUM * C8NUM; + int ic8_round = UP_ROUND(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int32_t zp; + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + zp = filter_zp[0].zp_; + } else { + zp = filter_zp[o].zp_; + } + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8_round * kernel_plane; + int i = 0; + for (; i < ic8; i += C8NUM) { + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + i * kernel_plane; +#ifdef ENABLE_ARM64 + int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset); + int16x8_t src_s16 = vmovl_s8(src_s8); + int16x4_t src1_s16 = vget_low_s16(src_s16); + int16x4_t src2_s16 = vget_high_s16(src_s16); + int32x4_t src1_s32 = vmovl_s16(src1_s16); + int32x4_t src2_s32 = vmovl_s16(src2_s16); + int32x4_t zp_s32 = vdupq_n_s32(zp); + int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); + int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); + int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); + int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); + vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); + vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); +#else + for (int ci = 0; ci < C8NUM; ++ci) { + (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset 
+ ci)[0] - zp); + } +#endif + } + dst_oc_offset += ic8 * kernel_plane; + for (; i < input_channel; i++) { + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); + } + } + } +} + +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { + /* normal matmul : 4x16 * 16x4 -> 4x4 */ +#ifdef ENABLE_ARM + PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp); +#else + for (size_t r = 0; r < row4; r++) { + int32_t tmp_value = 0; + for (size_t c = 0; c < col16; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; + int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; + tmp_value += src[src_index]; + } + dst[r] = tmp_value * filter_zp; + } +#endif + return; +} +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { + int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int unit = conv_param->input_h_ * conv_param->input_w_; + + for (int b = 0; b < conv_param->input_batch_; b++) { + const int8_t *src_b = src + b * unit * conv_param->input_channel_; + int16_t *dst_b = dst + b * unit * ic4 * C4NUM; + for (int k = 0; k < unit; k++) { + const int8_t *src_k = src_b + k * conv_param->input_channel_; + int16_t *dst_k = dst_b + k * ic4 * C4NUM; + for (int c = 0; c < conv_param->input_channel_; c++) { + dst_k[c] = (int16_t)(src_k[c] - input_zp); + } + } + } +} + +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} + +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C4NUM * k + c4_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_channel = c4 * C4NUM; + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; 
i++) { + int8_t *dst_per_plane = (int8_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (int8_t *)src + batch_offset + i * channel, channel); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc4_batch_offset + i * c4 * C4NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc8_batch_offset = b * nhwc8_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; + ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; + ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; + ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; + ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((int8_t *)dst + dst_res_c_offset)[0] = 
((int8_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; + } + } + } + return; +} + +void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) { + int hw8 = plane / C8NUM * C8NUM; + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const int8_t *src_batch = (const int8_t *)src + n * batch; + int8_t *dst_batch = (int8_t *)dst + n * batch; + int hw = 0; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v8.4h, v4.4h, v6.4h\n" + "trn2 v9.4h, v4.4h, v6.4h\n" + "trn1 v10.4h, v5.4h, v7.4h\n" + "trn2 v11.4h, v5.4h, v7.4h\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "trn1 v12.4h, v4.4h, v6.4h\n" + "trn2 v13.4h, v4.4h, v6.4h\n" + "trn1 v14.4h, v5.4h, v7.4h\n" + "trn2 v15.4h, v5.4h, v7.4h\n" + + "trn1 v0.2s, v8.2s, v12.2s\n" + "trn2 v4.2s, v8.2s, v12.2s\n" + "trn1 v1.2s, v10.2s, v14.2s\n" + "trn2 v5.2s, v10.2s, v14.2s\n" + "trn1 v2.2s, v9.2s, v13.2s\n" + "trn2 v6.2s, v9.2s, v13.2s\n" + "trn1 v3.2s, v11.2s, v15.2s\n" + "trn2 v7.2s, v11.2s, v15.2s\n" + + "st1 {v0.8b}, [x11], %[dstStride]\n" + "st1 {v1.8b}, [x11], %[dstStride]\n" + "st1 {v2.8b}, [x11], %[dstStride]\n" + "st1 {v3.8b}, [x11], %[dstStride]\n" + "st1 {v4.8b}, [x11], %[dstStride]\n" + "st1 {v5.8b}, [x11], %[dstStride]\n" + "st1 {v6.8b}, [x11], %[dstStride]\n" + "st1 {v7.8b}, [x11], %[dstStride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31"); +#elif ENABLE_ARM32 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.8 {d0}, [r10], %[srcStride]\n" + "vld1.8 {d1}, [r10], %[srcStride]\n" + "vld1.8 {d2}, [r10], %[srcStride]\n" + "vld1.8 {d3}, [r10], %[srcStride]\n" + "vld1.8 {d4}, [r10], %[srcStride]\n" + "vld1.8 {d5}, [r10], %[srcStride]\n" + "vld1.8 {d6}, [r10], %[srcStride]\n" + "vld1.8 {d7}, [r10], %[srcStride]\n" + + "vtrn.8 d0, d1\n" + "vtrn.8 d2, d3\n" + "vtrn.8 d4, d5\n" + "vtrn.8 d6, d7\n" + + "vtrn.16 d0, d2\n" + "vtrn.16 
d1, d3\n" + "vtrn.16 d4, d6\n" + "vtrn.16 d5, d7\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vst1.8 {d0}, [r12], %[dstStride]\n" + "vst1.8 {d1}, [r12], %[dstStride]\n" + "vst1.8 {d2}, [r12], %[dstStride]\n" + "vst1.8 {d3}, [r12], %[dstStride]\n" + "vst1.8 {d4}, [r12], %[dstStride]\n" + "vst1.8 {d5}, [r12], %[dstStride]\n" + "vst1.8 {d6}, [r12], %[dstStride]\n" + "vst1.8 {d7}, [r12], %[dstStride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < plane; hw++) { + const int8_t *src_ptr = src_batch + hw * channel; + int8_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..b89b74c933fb8c4acf9f3a536a8d5a61fe1fb51d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_PACK_INT8_H_ +#define NNACL_INT8_PACK_INT8_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param); +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, const ConvParameter *conv_param); +#ifdef ENABLE_ARM +void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); +void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, const int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, + size_t oc_res, size_t stride); +#endif + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg); +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg); + +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_PACK_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..cf042b6d9eaa0c36f28ba1b2ea3c0c27fde91e7a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/int8/pad_int8.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings, const int tid, const int thread_num) { + if (thread_num == 0) { + return NNACL_ERR; + } + int32_t copy_size = in_dims[3]; + for (int n = 0; n < in_dims[0]; n++) { + for (int h = tid; h < in_dims[1]; h += thread_num) { + for (int w = 0; w < in_dims[2]; w++) { + const int8_t *in = in_data + Offset(in_dims, n, h, w, 0); + int8_t *out = out_data + Offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]); + memcpy(out, in, (size_t)copy_size * sizeof(int8_t)); + } + } + } + return NNACL_OK; +} + +int TransOut2InputDimIndexInt8(int out_dim_index, int left_pad, int in_dim, int offset) { + if (out_dim_index < left_pad) { + // left pad + const int index_sum = left_pad + offset - 1; + return MSMAX(index_sum - out_dim_index, offset); + } + out_dim_index -= left_pad; + if (out_dim_index < in_dim) { + return out_dim_index; + } + // right pad + out_dim_index -= in_dim; + const int index_sum = in_dim - 1 - offset; + return MSMAX(index_sum - out_dim_index, 0); +} + +int GetInputFlattenIndexInt8(int out_flatten_index, const int32_t *input_shape, int mirror_offset, + const int *in_strides, const int *out_strides, const int *paddings) { + int in_flatten_index = 0; + int i; + for (i = 0; i < COMM_SHAPE_SIZE; ++i) { + int left_pad = paddings[i * 2]; + NNACL_CHECK_ZERO_RETURN_ERR(out_strides[i]); + int out_dim_index = out_flatten_index / out_strides[i]; + out_flatten_index %= out_strides[i]; + int in_dim_index = TransOut2InputDimIndexInt8(out_dim_index, left_pad, input_shape[i], mirror_offset); + in_flatten_index += in_dim_index * in_strides[i]; + } + return in_flatten_index; +} + +void MirrorPadInt8(const int8_t *in, int8_t *out, const int32_t *input_shape, int mirror_offset, const int *in_strides, + const int *out_strides, const int *paddings, int begin, int end) { + for (int i = begin; i < end; ++i) { + out[i] = in[GetInputFlattenIndexInt8(i, input_shape, mirror_offset, in_strides, out_strides, paddings)]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..08fe53c0566b7910a033d4d3a62a2205d5f8f290 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_PAD_INT8_H_ +#define NNACL_INT8_PAD_INT8_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/pad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings, const int tid, const int thread_num); +void MirrorPadInt8(const int8_t *in, int8_t *out, const int32_t *input_shape, int mirror_offset, const int *in_strides, + const int *out_strides, const int *paddings, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_PAD_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3f4de98dce589cc8ea33470140d485ac9608bcc9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c @@ -0,0 +1,516 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/pooling_int8.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int output_h = compute_args->output_h_; + int output_batch = compute_args->output_batch_; + int out_plane = output_w * output_h; + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + const int8_t out_min = INT8_MIN; + const int8_t out_max = INT8_MAX; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int16_t tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int
in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + if (real_count == 0) { + return NNACL_ERR; + } + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int8_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = real_out; + } // in_channel loop + } // out_plane loop + } // out_batch loop + return NNACL_OK; +} + +int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num) { + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int c16 = channel / C16NUM; + int in_w = compute_args->input_w_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = out_tile_count < thread_num ? out_tile_count : thread_num; + int input_zp = quant_args[0][0].zp_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = quant_args[0][0].scale_ / quant_args[1][0].scale_; + const int8_t out_min = INT8_MIN; + const int8_t out_max = INT8_MAX; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * compute_args->input_h_ * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
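+ /* The output plane is split into TILE_NUM-sized tiles dealt round-robin to threads via thread_id += thread_num; the last tile of the plane may be partial, hence this clamp of real_cal_num. */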
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + int out_plane_offset = out_batch_offset + index * channel; + int input_stride = (in_h_index * in_w + in_w_index) * channel; + int kw_s = MSMAX(0, -in_w_index); + int kw_e = MSMIN(win_w, in_w - in_w_index); + int kh_s = MSMAX(0, -in_h_index); + int kh_e = MSMIN(win_h, compute_args->input_h_ - in_h_index); + int real_count = (kw_e - kw_s) * (kh_e - kh_s); + if (real_count == 0) { + return NNACL_ERR; + } + + // 16 channels + for (int j = 0; j < c16; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg[2]; + tmp_avg[0] = vmovq_n_s16(0); + tmp_avg[1] = vmovq_n_s16(0); +#else + int16_t tmp_avg[16]; + int16_t real_out[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_avg[m] = 0; + } +#endif + int in_channel_offset = in_batch_offset + j * C16NUM; + int out_channel_offset = out_plane_offset + j * C16NUM; + + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset); + int8x8_t in_data1 = vget_low_s8(in_ptr); + int8x8_t in_data2 = vget_high_s8(in_ptr); + int16x8_t data1 = vmovl_s8(in_data1); + int16x8_t data2 = vmovl_s8(in_data2); + tmp_avg[0] = vaddq_s16(tmp_avg[0], data1); + tmp_avg[1] = vaddq_s16(tmp_avg[1], data2); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + int16_t tmp_data1[8]; + int16_t tmp_out1[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[0][l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + for (int l = 0; l < C8NUM; l++) { + tmp_data1[l] = tmp_avg[1][l] + 128 * real_count; + tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count; + tmp_out1[l] -= 128; + tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out[2]; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out[0] = vqmovn_s16(vld1q_s16(tmp_out)); + real_out[0] = vmin_s8(real_out[0], output_max); + real_out[0] = vmax_s8(real_out[0], output_min); + vst1_s8(output_ptr + out_channel_offset, real_out[0]); + real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1)); + real_out[1] = vmin_s8(real_out[1], output_max); + real_out[1] = vmax_s8(real_out[1], output_min); + vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]); +#else + for (int l = 0; l < C16NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? 
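+ /* The +128 * real_count bias above keeps the accumulated int8 sum non-negative, so (x + real_count / 2) / real_count is a round-to-nearest integer division; the trailing -128 restores the signed range before requantization. */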
out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // 8 channels + int channel_16_res = channel - c16 * C16NUM; + int c8 = channel_16_res / C8NUM; + int in_c16_offset = in_batch_offset + c16 * C16NUM; + int out_c16_offset = out_plane_offset + c16 * C16NUM; + for (int j = 0; j < c8; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg = vmovq_n_s16(0); +#else + int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + int16_t real_out[8]; +#endif + int in_channel_offset = in_c16_offset + j * C8NUM; + int out_channel_offset = out_c16_offset + j * C8NUM; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x8_t in_ptr = vld1_s8(input_ptr + in_offset); + int16x8_t data = vmovl_s8(in_ptr); + tmp_avg = vaddq_s16(tmp_avg, data); +#else + for (int k = 0; k < C8NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out = vqmovn_s16(vld1q_s16(tmp_out)); + real_out = vmin_s8(real_out, output_max); + real_out = vmax_s8(real_out, output_min); + vst1_s8(output_ptr + out_channel_offset, real_out); +#else + for (int l = 0; l < C8NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // less than 8 channel + int channel_8_res = channel_16_res - c8 * C8NUM; + int in_c8_offset = in_c16_offset + c8 * C8NUM; + int out_c8_offset = out_c16_offset + c8 * C8NUM; + for (int k = 0; k < channel_8_res; k++) { + int in_channel_offset = in_c8_offset + k; + int out_channel_offset = out_c8_offset + k; + int16_t tmp_avg = 0; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; + tmp_avg += input_ptr[in_offset]; + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128; + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int16_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? 
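+ /* Residual path for the last channel % C8NUM channels: same bias, rounding and requantization as the vector paths, finished with a clamp to [out_min, out_max] before narrowing to int8. */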
out_max : real_out; + *(output_ptr + out_channel_offset) = (int8_t)real_out; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } + return NNACL_OK; +} + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int output_h = compute_args->output_h_; + int output_batch = compute_args->output_batch_; + int out_plane = output_w * output_h; + // input channel is equal to output channel + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num) { + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = out_tile_count < thread_num ? 
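+ /* Cap the worker count at the tile count: a thread whose task_id is not below out_tile_count would otherwise find no tile to process. */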
out_tile_count : thread_num; + int c16 = UP_DIV(channel, 16); + // input channel is equal to output channel + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c16 - 1; j++) { + int in_channel_offset = in_batch_offset + j * 16; + int out_channel_offset = out_plane_offset + j * 16; +#ifdef ENABLE_NEON + int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); +#else + int8_t tmp_max[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_max[m] = INT8_MIN; + } +#endif + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + for (int l = 0; l < C16NUM; ++l) { + tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } + vst1q_s8(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C16NUM; ++l) { + *(output_ptr + out_channel_offset + l) = + (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } +#endif + } // in_channel loop + + // res channel + int channel_s = (c16 - 1) * 16; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, int task_id, int thread_num) { + int channel = compute_args->input_channel_; + 
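+ /* This opt variant takes no QuantArg; it presumes input and output share one scale and zero-point, so the running max is taken directly on raw int8 values and staged through the on-stack out_array strip of at most MAX_MAXPOOL_SIZE channels. */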
int in_w = compute_args->input_w_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = MSMIN(out_tile_count, thread_num); + int8_t out_array[MAX_MAXPOOL_SIZE]; + + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * compute_args->input_h_ * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = out_plane - cal_start_index; + real_cal_num = MSMIN(real_cal_num, TILE_NUM); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + const int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index); + int ky_e = MSMIN(compute_args->window_h_, compute_args->input_h_ - in_h_index); + const int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index); + int kx_e = MSMIN(compute_args->window_w_, in_w - in_w_index); + int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset; + int out_plane_offset = out_batch_offset + index * channel; + + int c = 0; + for (; c < channel; c += MAX_MAXPOOL_SIZE) { + int real_channel = channel - c; + real_channel = MSMIN(real_channel, MAX_MAXPOOL_SIZE); + memset(out_array, INT8_MIN, real_channel); + int8_t *out_data = output_ptr + out_plane_offset + c; + for (int h = ky_s; h < ky_e; ++h) { + int in_h_offset = input_stride + h * in_w * channel + c; + for (int w = kx_s; w < kx_e; ++w) { + const int8_t *in_data = input_ptr + in_h_offset + w * channel; + int j = 0; +#ifdef ENABLE_NEON + const int8_t *tmp_in_data = in_data; + int c16 = real_channel / 16 * 16; + int c8 = real_channel / 8 * 8; + for (; j < c16; j += 16) { + int8x16_t ori_in = vld1q_s8(tmp_in_data); + int8x16_t out_array16 = vld1q_s8(out_array + j); + tmp_in_data += 16; + out_array16 = vmaxq_s8(ori_in, out_array16); + vst1q_s8(out_array + j, out_array16); + } // 16 channel loop + + for (; j < c8; j += 8) { + int8x8_t ori_in = vld1_s8(tmp_in_data); + int8x8_t out_array8 = vld1_s8(out_array + j); + tmp_in_data += 8; + out_array8 = vmax_s8(ori_in, out_array8); + vst1_s8(out_array + j, out_array8); + } // 8 channel loop +#endif + for (; j < real_channel; ++j) { + out_array[j] = out_array[j] > in_data[j] ? 
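+ /* Scalar tail of the strip: channels not covered by the 16- and 8-lane NEON loops above fall through to this plain byte-wise max. */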
out_array[j] : in_data[j]; + } + } // kw loop + } // kh loop + + int j = 0; +#ifdef ENABLE_NEON + int c16 = real_channel / 16 * 16; + int c8 = real_channel / 8 * 8; + int8_t *tmp_out_data = out_data; + for (; j < c16; j += 16) { + vst1q_s8(tmp_out_data, vld1q_s8(out_array + j)); + tmp_out_data += 16; + } // 16 channel loop + + for (; j < c8; j += 8) { + vst1_s8(tmp_out_data, vld1_s8(out_array + j)); + tmp_out_data += 8; + } // 8 channel loop +#endif + for (; j < real_channel; ++j) { + out_data[j] = out_array[j]; + } + } // 256 channel loop + } // out_plane loop + } // out_batch loop + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1be326c9da2e3a6dead21f23fd576299da1e66 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_POOLING_H_ +#define NNACL_INT8_POOLING_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/pooling_fp32.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define MAX_MAXPOOL_SIZE 256 + +int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args); + +int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num); + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args); + +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num); + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3fc6a60cf94a417853cb42a5e8dde72cdcd15cf0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/power_int8.h" + +int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, const PowQuantArg *args, + bool broadcast, int scale, int shift) { + double input_scale = args->in_args_.scale_; + int input_zp = args->in_args_.zp_; + double output_scale = args->out_args_.scale_; + int output_zp = args->out_args_.zp_; + int act_min = args->output_activation_min_; + int act_max = args->output_activation_max_; + double exp_scale = args->exp_args_.scale_; + int exp_zp = args->exp_args_.zp_; + + if (broadcast) { + float exp_val = exp_scale * (exp_ptr[0] - exp_zp); + for (int i = 0; i < count; ++i) { + float input_val = input_scale * (input[i] - input_zp); + float output_val = pow(scale * input_val + shift, exp_val); + int32_t output_scaled = round(output_val / output_scale) + output_zp; + output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max)); + } + } else { + for (int i = 0; i < count; ++i) { + float input_val = input_scale * (input[i] - input_zp); + float exp_val = exp_scale * (exp_ptr[i] - exp_zp); + float output_val = pow(scale * input_val + shift, exp_val); + int32_t output_scaled = round(output_val / output_scale) + output_zp; + output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max)); + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..7bf8b80947afa4c0de3aacfe9edbec19ce719042 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_POWER_INT8_H_ +#define NNACL_INT8_POWER_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/pow_parameter.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, const PowQuantArg *args, + bool broadcast, int scale, int shift); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_INT8_POWER_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..edd05f50de63366812363e341536c54c4d2857e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c @@ -0,0 +1,437 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <math.h> +#include "nnacl_c/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/errorcode.h" +#ifdef ENABLE_ARM +#include <arm_neon.h> +#endif + +#ifdef ENABLE_ARM64 +inline void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size) { + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v20.4s, %w[zp32]\n" + "dup v21.4s, %w[scale]\n" + + "cmp w8, #16\n" + "blt 1f\n" + + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v7.16b}, [%[quant_values]], #16\n" + + "sxtl v8.8h, v7.8b\n" + "sxtl2 v9.8h, v7.16b\n" + + "sxtl v0.4s, v8.4h\n" + "sxtl2 v1.4s, v8.8h\n" + "sxtl v2.4s, v9.4h\n" + "sxtl2 v3.4s, v9.8h\n" + "sub v0.4s, v0.4s, v20.4s\n" + "sub v1.4s, v1.4s, v20.4s\n" + "sub v2.4s, v2.4s, v20.4s\n" + "sub v3.4s, v3.4s, v20.4s\n" + "scvtf v4.4s, v0.4s\n" + "scvtf v5.4s, v1.4s\n" + "scvtf v6.4s, v2.4s\n" + "scvtf v7.4s, v3.4s\n" + + "fmul v0.4s, v4.4s, v21.4s\n" + "fmul v1.4s, v5.4s, v21.4s\n" + "fmul v2.4s, v6.4s, v21.4s\n" + "fmul v3.4s, v7.4s, v21.4s\n" + + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[dst]], #64\n" + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "ldrsb w9, [%[quant_values]], #1\n" + + "subs w8, w8, #1\n" + "sub w9, w9, %w[zp32]\n" + "scvtf s9, w9\n" + + "fmul s9, s9, s21\n" + "str s9, [%[dst]], #4\n" + "bne 1b\n" + + "2:\n" + + : + : [quant_values] "r"(quant_values), [dst] "r"(dst), [scale] "r"(scale), [zp32] "r"(zp), [size] "r"(size) + : "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21"); +} +#endif + +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + +#ifdef ENABLE_ARM64 + Int8ToFp32_arm64(quant_values, real_values, scale, zp, size); +#else + for (int i = 0; i < size; i++) { + real_values[i] = (quant_values[i] - zp) * scale; + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +inline void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t
max_value) { + float ivs = 1.0f / scale; + + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v12.4s, %w[ivs]\n" + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [quant_values] "r"(quant_values), [real_values] "r"(real_values), [scale] "r"(scale), [zp] "r"(zp), + [size] "r"(size), [ivs] "r"(ivs), [min_value] "r"(min_value), [max_value] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14"); +} +#endif + +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8_arm64(real_values, quant_values, scale, zp, size, min_value, max_value); +#else + const float inverse_scale = 1.0f / scale; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? 
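+ /* Illustrative numbers, not from the source: with scale = 0.5f and zp = 10, the value 3.0f maps to round(3.0 * 2.0 + 10) = 16; out-of-range results saturate at min_value / max_value, and the INFINITY cases are pinned explicitly above. */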
temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +inline void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, + int size, int row_length, int32_t min_value, int32_t max_value) { + volatile float ivs[size]; + for (int i = 0; i < size; i++) { + volatile int channel_index = i / row_length; + ivs[i] = 1.0f / scales[channel_index]; + } + volatile int32_t zp = zps[0]; + + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "mov x4, %[ivs]\n" // reload ivs + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [quant_values] "r"(quant_values), [real_values] "r"(real_values), [scales] "r"(scales), [zp] "r"(zp), + [size] "r"(size), [row_length] "r"(row_length), [ivs] "r"(ivs), [min_value] "r"(min_value), + [max_value] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "x4"); +} +#endif + +int DoChannelRowFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL || row_length == 0) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8Perchannel_arm64(real_values, quant_values, scale, zp, size, row_length, min_value, max_value); +#else + for (int i = 0; i < size; ++i) { + int channel_index = i / row_length; + const float inverse_scale = 1.0f / scale[channel_index]; + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp[channel_index]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? 
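+ /* Row-major per-channel layout: element i belongs to channel i / row_length, so each contiguous run of row_length values shares one scale and zero-point; the column variant below indexes scale[c] per element instead. */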
temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +int DoChannelColFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL || row_length == 0) { + return NNACL_PARAM_INVALID; + } + int row_total = size / row_length; + for (int r = 0; r < row_total; r++) { + const float *real_current = real_values + r * row_length; + int8_t *quant_current = quant_values + r * row_length; + for (int c = 0; c < row_length; c++) { + const float inverse_scale = 1.0f / scale[c]; + if (real_current[c] == INFINITY) { + quant_current[c] = max_value; + } else if (real_current[c] == -INFINITY) { + quant_current[c] = min_value; + } else { + int temp = round(real_current[c] * inverse_scale + zp[c]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? temp : min_value; + quant_current[c] = (int8_t)temp; + } + } + } + return NNACL_OK; +} + +int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, + int size, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + zp += 128; + const float inverse_scale = 1.0f / scale; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp -= 128; + temp = temp < 127 ? temp : 127; + temp = temp > -128 ? temp : -128; + quant_values[i] = (int8_t)temp; + } + } + return NNACL_OK; +} + +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (float)((int)quant_values[i] - zp) * scale; + } + return NNACL_OK; +} + +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + if (isinf(real_values[i])) { + quant_values[i] = 255; + } else { + float temp = (float)round(real_values[i] * 1.0 / scale + zp); + if (temp > 255) { + quant_values[i] = 255; + } else if (temp < 0) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } + } + } + return NNACL_OK; +} + +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + int temp = quant_values[i] + 128; + if (temp > 255) { + real_values[i] = (uint8_t)255; + } else if (temp < 0) { + real_values[i] = 0; + } else { + real_values[i] = (uint8_t)temp; + } + } + return NNACL_OK; +} + +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + int temp = (int)real_values[i] - 128; + if (temp > 127) { + quant_values[i] = 127; + } else if (temp < -128) { + quant_values[i] = -128; + } else { + quant_values[i] = (int8_t)temp; + } + } + return NNACL_OK; +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..6de46c60d0ce3e127827800dba6c5456b201460a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_QUANTDTYPECAST_H_ +#define NNACL_INT8_QUANTDTYPECAST_H_ + +#include "nnacl_c/op_base.h" + +typedef struct QuantDTypeCastParameter { + OpParameter op_parameter_; + int32_t srcT; + int32_t dstT; + int32_t axis; +} QuantDTypeCastParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value); +int DoChannelRowFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value); +int DoChannelColFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value); +int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, + int size, int32_t min_value, int32_t max_value); +#ifdef ENABLE_ARM64 +void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value); +void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size); +void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, int size, + int row_length, int32_t min_value, int32_t max_value); +#endif +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_QUANTDTYPECAST_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c new file mode 100644 index 0000000000000000000000000000000000000000..e9eadb23f1c081712a1c7b458f09fd23210464c5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/quantize.h" +#include <math.h> +#include <limits.h> + +const uint64_t dSignMask = 1ull << 63; +const uint64_t dExponentMask = 0x7ffull << 52; +const uint64_t dFractionMask = (1ull << 52) - 1; +const int dExponentBias = 1022; +const int dMantissaBits = 52; +const int dInfiniteExponent = 0x7ff; +const double dNormalizer = 0x1p54; +const int dNormalizerBias = 54; +const int iMantissaBits = 31; + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int32_t *right_shift) { + if (quantized_multiplier == NULL || right_shift == NULL) { + return; + } + int shift = 0; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + *right_shift = -shift; +} + +void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift) { + int shift = 0; + QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); + shift = -shift; + if (shift < 0) { + *left_shift = 0; + *right_shift = shift; + } else { + *left_shift = shift; + *right_shift = 0; + } +} + +void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift) { + int shift = 0; + const uint32_t scale_bits = (uint32_t)(double_multiplier); + /* multiplier is in[0x40000000, 0x7FFFFF80] range */ + *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) { + return; + } + /* shift is in [0, 31] range */ + shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23); + shift = -shift; + if (shift < 0) { + *left_shift = 0; + *right_shift = shift; + } else { + *left_shift = shift; + *right_shift = 0; + } +} + +uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int32_t *mini, + int32_t *maxi) { + int32_t min = INT8_MIN; + int32_t max = INT8_MAX; + int32_t quantized_zero = QuantizeToInt8(0, scale, zp); + int32_t quantized_six = QuantizeToInt8(6, scale, zp); + if (is_relu) { + min = min > quantized_zero ? min : quantized_zero; + } else if (is_relu6) { + min = min > quantized_zero ? min : quantized_zero; + max = max < quantized_six ? max : quantized_six; + } else { + // do nothing + } + *mini = min; + *maxi = max; +} + +// quantize from float to int8 +void Quantize(const float *input_data, int length, float scale, int zero_point, int8_t *output_data) { + for (int i = 0; i < length; ++i) { + int q = (int)round(input_data[i] / scale + zero_point); + q = q > SCHAR_MAX ? SCHAR_MAX : q; + q = q < SCHAR_MIN ?
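+ /* Clamp to [SCHAR_MIN, SCHAR_MAX] before the narrowing cast so out-of-range values saturate instead of wrapping. */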
SCHAR_MIN : q; + output_data[i] = (int8_t)q; + } +} + +// dequantize from int8 to float +void Dequantize(const int8_t *input_data, int length, float scale, int zero_point, float *output_data) { + for (int i = 0; i < length; ++i) { + output_data[i] = scale * (input_data[i] - zero_point); + } +} + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int32_t *shift) { + if (quantized_multiplier == NULL || shift == NULL) { + return; + } + // we split a floating number into two parts: exponent and fraction + // since fraction is stored as int32, only 31 bits of mantissa is remained + union { + double d; + uint64_t ul; + } dul; + dul.d = double_multiplier; + if (!(dul.ul & (~dSignMask))) { + // multiplier is 0 + *quantized_multiplier = 0; + *shift = 0; + return; + } + int exponent = (int)((dul.ul & dExponentMask) >> dMantissaBits); + if (exponent == dInfiniteExponent) { + // multiplier is inf or NaN + *shift = 0; + if (!(dul.ul & dFractionMask)) { + // inf + *quantized_multiplier = (dul.ul & dSignMask) ? INT_MIN : INT_MAX; + } else { + // NaN + *quantized_multiplier = 0; + } + return; + } + if (exponent == 0) { + // multiplier is a subnormal number + dul.d *= dNormalizer; + exponent = (int)((dul.ul & dExponentMask) >> dMantissaBits); + *shift = exponent - dExponentBias - dNormalizerBias; + } else { + *shift = exponent - dExponentBias; + } + uint64_t fraction = dul.ul & dFractionMask; + fraction += (1ull << dMantissaBits); + uint64_t rounded = ((fraction >> (dMantissaBits - iMantissaBits)) + 1ull) >> 1; + // we get 31 rounded bits now + if (rounded == (1ull << iMantissaBits)) { + // rounding may cause a carry + rounded >>= 1; + ++*shift; + } + *quantized_multiplier = (dul.ul & dSignMask) ? (-(int32_t)(rounded)) : (int32_t)(rounded); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..380713c517a03451e9f0d479236fee4645dc16d9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h @@ -0,0 +1,222 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_QUANTIZATION_QUANTIZE_H_ +#define NNACL_QUANTIZATION_QUANTIZE_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" + +#define INPUT_PER_CHANNEL 0b001 +#define FILTER_PER_CHANNEL 0b010 +#define OUTPUT_PER_CHANNEL 0b100 + +typedef struct ConvQuantArg { + RoundingMode round_mode_; + CalFixedMultiplierMode quant_multiplier_mode_; + QuantArg *input_quant_args_; + QuantArg *filter_quant_args_; + QuantArg *output_quant_args_; + double *real_multiplier_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; + int32_t *out_act_min_; + int32_t *out_act_max_; + size_t input_arg_num_; + size_t filter_arg_num_; + size_t output_arg_num_; + uint8_t per_channel_; +} ConvQuantArg; + +typedef struct ConcatQuantArg { + QuantArg *in_args_; + QuantArg out_args_; + int8_t output_activation_min_; + int8_t output_activation_max_; +} ConcatQuantArg; + +typedef struct PreluQuantArg { + int32_t *input_sizes_; + int output_size_; + int32_t **input_shapes_; + int32_t *output_shape_; + size_t input_num_; + size_t output_dim_; + float alpha_; + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + QuantArg *in_quant_args_; + QuantArg out_quant_args_; +} PreluQuantArg; + +typedef struct CropQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} CropQuantArg; + +typedef struct ArithSelfQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +} ArithSelfQuantArg; + +typedef struct GatherQuantArg { + double alpha_; + int zp_in_; + int zp_out_; +} GatherQuantArg; + +typedef struct DynamicGatherQuantArg { + float *scale_in_; + int32_t *zp_in_; +} DynamicGatherQuantArg; + +typedef struct SoftmaxQuantArg { + QuantArg in_quant_args_; + QuantArg out_quant_arg_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +} SoftmaxQuantArg; + +typedef struct SubQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int input0_multiplier_; + int input1_multiplier_; + int output_multiplier_; + int input0_shift_; + int input1_shift_; + int output_shift_; + int left_shift_result0_; + int left_shift_result1_; + int right_shift0_; + int right_shift1_; + int left_shift_out_; + int right_shift_out_; +} SubQuantArg; + +typedef struct ArithmeticQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; +} ArithmeticQuantArg; + +typedef struct DivQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int output_shift_; +} DivQuantArg; + +typedef struct ReduceQuantArg { + double in_scale_; + int32_t in_zp_; + double out_scale_; + int32_t out_zp_; + int32_t in_out_multiplier_; + int in_out_left_shift_; + int in_out_right_shift_; + int32_t mean_multiplier_; + int mean_left_shift_; + int mean_right_shift_; + int32_t prod_multiplier_; + int prod_left_shift_; + int prod_right_shift_; + int32_t sum_square_multiplier_; + int sum_square_left_shift_; + int sum_square_right_shift_; +} ReduceQuantArg; + +typedef struct LeakyReluQuantArg { + QuantArg in_args_; + QuantArg out_args_; + float slope_; + int input_dim_; + int element_num; + int thread_num_; +} LeakyReluQuantArg; + +typedef struct ResizeQuantArg { +
int32_t ratio_x_; + int32_t ratio_y_; + int32_t *x_axis_index_; + int32_t *x_axis_lower_; + int32_t *x_axis_upper_; + int32_t *y_axis_index_; + int32_t *y_axis_lower_; + int32_t *y_axis_upper_; +} ResizeQuantArg; + +typedef struct ResizeFloatScaleQuantArg { + float ratio_x_; + float ratio_y_; + float *x_axis_index_; + int32_t *x_axis_lower_; + int32_t *x_axis_upper_; + float *y_axis_index_; + int32_t *y_axis_lower_; + int32_t *y_axis_upper_; +} ResizeFloatScaleQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int32_t *shift); + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int32_t *right_shift); + +void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift); + +void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift); + +uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp); + +int32_t QuantizeToInt8(float real_value, float scale, int32_t zp); + +void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int32_t *mini, + int32_t *maxi); +// quantize from float to int8 +void Quantize(const float *input_data, int length, float scale, int zero_point, int8_t *output_data); + +// dequantize from int8 to float +void Dequantize(const int8_t *input_data, int length, float scale, int zero_point, float *output_data); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_QUANTIZATION_QUANTIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..be1aca7b3f9d090574521bb1e338a6e135e90622 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c @@ -0,0 +1,597 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdint.h> +#include "nnacl_c/int8/reduce_int8.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/common_func.h" + +int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg, + int32_t bias) { + int stride = plane * UP_ROUND(c, C4NUM); + for (int batch = 0; batch < n; ++batch) { + int8_t *in_ptr = in_data + batch * stride; + int8_t *out_ptr = out_data + batch * c; + for (int i = 0; i < count; ++i) { + int32_t sum_array = 0; + int j = 0; +#ifdef ENABLE_ARM64 + for (; j <= plane - 16; j += 16) { + int8x16_t in_data_vec = vld1q_s8(in_ptr); + sum_array += vaddlvq_s8(in_data_vec); + in_ptr += 16; + } + for (; j <= plane - 8; j += 8) { + int8x8_t in_data_vec = vld1_s8(in_ptr); + sum_array += vaddlv_s8(in_data_vec); + in_ptr += 8; + } + for (; j <= plane - 4; j += 4) { + int32x4_t in_data_vec; + in_data_vec[0] = in_ptr[0]; + in_data_vec[1] = in_ptr[1]; + in_data_vec[2] = in_ptr[2]; + in_data_vec[3] = in_ptr[3]; + sum_array += vaddvq_s32(in_data_vec); + in_ptr += 4; + } +#elif ENABLE_ARM32 + int32x4_t accum = vmovq_n_s32(0); + for (; j <= plane - 16; j += 16) { + int32x4_t in_data_vec1; + int32x4_t in_data_vec2; + int32x4_t in_data_vec3; + int32x4_t in_data_vec4; + in_data_vec1[0] = in_ptr[0]; + in_data_vec1[1] = in_ptr[1]; + in_data_vec1[2] = in_ptr[2]; + in_data_vec1[3] = in_ptr[3]; + in_data_vec2[0] = in_ptr[4]; + in_data_vec2[1] = in_ptr[5]; + in_data_vec2[2] = in_ptr[6]; + in_data_vec2[3] = in_ptr[7]; + in_data_vec3[0] = in_ptr[8]; + in_data_vec3[1] = in_ptr[9]; + in_data_vec3[2] = in_ptr[10]; + in_data_vec3[3] = in_ptr[11]; + in_data_vec4[0] = in_ptr[12]; + in_data_vec4[1] = in_ptr[13]; + in_data_vec4[2] = in_ptr[14]; + in_data_vec4[3] = in_ptr[15]; + accum = vaddq_s32(accum, in_data_vec1); + accum = vaddq_s32(accum, in_data_vec2); + accum = vaddq_s32(accum, in_data_vec3); + accum = vaddq_s32(accum, in_data_vec4); + in_ptr += 16; + } + for (; j <= plane - 8; j += 8) { + int32x4_t in_data_vec1; + int32x4_t in_data_vec2; + in_data_vec1[0] = in_ptr[0]; + in_data_vec1[1] = in_ptr[1]; + in_data_vec1[2] = in_ptr[2]; + in_data_vec1[3] = in_ptr[3]; + in_data_vec2[0] = in_ptr[4]; + in_data_vec2[1] = in_ptr[5]; + in_data_vec2[2] = in_ptr[6]; + in_data_vec2[3] = in_ptr[7]; + accum = vaddq_s32(accum, in_data_vec1); + accum = vaddq_s32(accum, in_data_vec2); + in_ptr += 8; + } + for (; j <= plane - 4; j += 4) { + int32x4_t in_data_vec; + in_data_vec[0] = in_ptr[0]; + in_data_vec[1] = in_ptr[1]; + in_data_vec[2] = in_ptr[2]; + in_data_vec[3] = in_ptr[3]; + accum = vaddq_s32(accum, in_data_vec); + in_ptr += 4; + } + sum_array += accum[0]; + sum_array += accum[1]; +
sum_array += accum[2]; + sum_array += accum[3]; +#endif + for (; j < plane; j++) { + sum_array += in_ptr[0]; + in_ptr++; + } + int32_t mean = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum_array * (1 << (unsigned int)quant_arg.left_shift_), + quant_arg.multiplier_), + quant_arg.right_shift_); + mean += bias; + *out_ptr++ = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN); + } + } + return NNACL_OK; +} + +int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +// Get x such that (x-zp_in) * scale_in = mean +// Assuming reduce n axes, this works for first n-1 reduce. One call for one reduce. +int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t mean = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_), + quant->mean_right_shift_); + if (isAddOverflow(mean, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = mean + quant->in_zp_; + } + } + return NNACL_OK; +} + +// suppose reduce n axes, this works for last reduce axis. 
+// get y such that (y-zp_out) * scale_out = mean(x-zp_in)*scale_in +int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t mean = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_), + quant->mean_right_shift_); + // trans to output scale + int32_t mean_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(mean * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(mean_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + mean = mean_scaled + quant->out_zp_; + + *inner_dst = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN); + } + } + return NNACL_OK; +} + +// Get x such that (x-zp_in) * scale_in = sum(item-zp_in)*scale_in +// Assuming reduce n axes, this works for first n-1 reduce. One call for one reduce. +int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + + if (isAddOverflow(quant->in_zp_, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = sum + quant->in_zp_; + } + } + return NNACL_OK; +} + +// suppose reduce n axes, this works for last reduce axis. 
+// get y such that (y-zp_out) * scale_out = sum(item-zp_in)*scale_in +int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t sum_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(sum_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum = sum_scaled + quant->out_zp_; + if (sum > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (sum < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)sum; + } + } + } + return NNACL_OK; +} + +int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MIN; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + int32_t tmp_scaled = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul((tmp - quant->in_zp_) * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(tmp_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + tmp = tmp_scaled + quant->out_zp_; + if (tmp > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (tmp < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MIN; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? 
tmp : inner_src[i * inner_size]; + } + + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + const int base_offset = 20; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + int32_t tmp_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + (tmp - quant->in_zp_) * (1 << ((unsigned int)quant->in_out_left_shift_ + base_offset)), + quant->in_out_multiplier_), + quant->in_out_right_shift_ + base_offset); + if (isAddOverflow(tmp_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + tmp = tmp_scaled + quant->out_zp_; + if (tmp > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (tmp < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? 
tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t prod = 1; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isMulOverflow(prod, tmp)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + prod *= tmp; + } + prod = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_), + quant->prod_right_shift_); + int32_t prod_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(prod_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + prod = prod_scaled + quant->out_zp_; + if (prod > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (prod < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)prod; + } + } + } + return NNACL_OK; +} + +int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t prod = 1; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isMulOverflow(prod, tmp)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + prod *= tmp; + } + prod = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_), + quant->prod_right_shift_); + if (isAddOverflow(prod, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = prod + quant->in_zp_; + } + } + return NNACL_OK; +} + +int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp; + if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * 
inner_size] - quant->in_zp_); + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t sum_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_), + quant->sum_square_multiplier_), + quant->sum_square_right_shift_); + if (isAddOverflow(sum_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum = sum_scaled + quant->out_zp_; + + if (sum > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (sum < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)sum; + } + } + } + return NNACL_OK; +} + +int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp; + if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * inner_size] - quant->in_zp_); + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + sum = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_), + quant->sum_square_multiplier_), + quant->sum_square_right_shift_); + if (isAddOverflow(sum, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = sum + quant->in_zp_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..f8302ae87a4d5aa542ce6f40815ac1af5181016b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_REDUCE_INT8_H_ +#define NNACL_INT8_REDUCE_INT8_H_ + +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg, + int32_t bias); +int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); + +int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); 
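/* The prototypes above share one contract: the *Int8 first-pass reducers keep
 * their int32 results in the input quantization domain (they return x such
 * that (x - in_zp_) * in_scale_ equals the partial result), while the
 * *LastAxis reducers rescale into the output domain through the fixed-point
 * pair (in_out_multiplier_, in_out_left_shift_/in_out_right_shift_), which the
 * kernel setup is expected to derive from in_scale_ / out_scale_ via
 * QuantizeMultiplier(). A minimal sketch of that final rescale step follows;
 * the helper name is hypothetical, and it inlines the fixed_point.h
 * primitives (SaturatingRoundingDoublingHighMul, RoundingDivideByPOT, minus
 * their INT32_MIN saturation edge case) so that it is self-contained. */
static inline int8_t ReduceRescaleToOutputInt8(int32_t acc, const ReduceQuantArg *quant) {
  /* Fixed-point multiply: high 32 bits of 2 * a * b, rounded to nearest. */
  int64_t ab = (int64_t)(acc * (1 << (unsigned int)quant->in_out_left_shift_)) * quant->in_out_multiplier_;
  int64_t nudge = (ab >= 0) ? (1ll << 30) : (1 - (1ll << 30));
  int32_t high = (int32_t)((ab + nudge) >> 31);
  /* Divide by 2^right_shift, rounding half away from zero. */
  int32_t shift = quant->in_out_right_shift_;
  int32_t mask = (int32_t)(((int64_t)1 << (unsigned int)shift) - 1);
  int32_t remainder = high & mask;
  int32_t threshold = (mask >> 1) + ((high < 0) ? 1 : 0);
  int32_t out = (high >> shift) + ((remainder > threshold) ? 1 : 0);
  /* Re-center on the output zero point and clamp to int8. */
  out += quant->out_zp_;
  return (int8_t)(out > INT8_MAX ? INT8_MAX : (out < INT8_MIN ? INT8_MIN : out));
}
/* The *LastAxis implementations in reduce_int8.c follow this same sequence,
 * adding explicit isAddOverflow/isMulOverflow checks around the accumulation. */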
+int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif +#endif // NNACL_INT8_REDUCE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..843283b8156ce9def59f23e66eba85c85dde0d9d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/int8/relux_int8.h" + +void ReluXInt8(const int8_t *src, int length, int8_t *dst, const ReluXQuantArg *arg) { + for (int i = 0; i < length; ++i) { + if (src[i] <= arg->input_arg.zp_) { + dst[i] = arg->output_arg.zp_; + continue; + } + const int32_t input_val = src[i] - arg->input_arg.zp_; + const int32_t scaled_input = SaturatingRoundingDoublingHighMul(input_val, arg->input_multiplier_); + const int32_t shifted_input = RoundingDivideByPOT(scaled_input * (1U << arg->left_shift_), -arg->right_shift_); + const int32_t output = shifted_input + arg->output_arg.zp_; + dst[i] = (int8_t)MSMIN(output, arg->quantized_output_max); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..0676cf1063c0db5718c659c650ed6403555fb49d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_INT8_RELU_INT8_H_ +#define NNACL_INT8_RELU_INT8_H_ + +#include <stdint.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct ReluXQuantArg { + QuantArg input_arg; + QuantArg output_arg; + int input_multiplier_; + int left_shift_; + int right_shift_; + int quantized_output_min; + int quantized_output_max; +} ReluXQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif +void ReluXInt8(const int8_t *src, int length, int8_t *dst, const ReluXQuantArg *arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_RELU_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..4b3fc2006d9fa40c70558b75ccca38afff1f1220 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/reshape_int8.h" +#include "nnacl_c/reshape_parameter.h" +#include <string.h> + +void Int8Reshape(const int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count, ReshapeQuantArg para) { + if (para.in_args_.scale_ == para.out_args_.scale_ && para.in_args_.zp_ == para.out_args_.zp_) { + memcpy(output_ptr, input_ptr, real_dst_count); + } else { + const float output_inverse_scale = 1.f / para.out_args_.scale_; + float scale = para.in_args_.scale_ * output_inverse_scale; + float bias = -para.in_args_.zp_ * scale; + int32_t output_zp = para.out_args_.zp_; + for (int i = 0; i < real_dst_count; i++) { + int32_t output_tmp = round(input_ptr[i] * scale + bias) + output_zp; + if (output_tmp > para.output_activation_max_) { + output_ptr[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output_ptr[i] = para.output_activation_min_; + } else { + output_ptr[i] = (int8_t)output_tmp; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..46fb480c60c6704f8037a050c00ce38a5c2fc2d2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_INT8_RESHAPE_INT8_H_ +#define NNACL_INT8_RESHAPE_INT8_H_ + +#include <stdint.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/reshape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Reshape(const int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count, ReshapeQuantArg para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_RESHAPE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..4f7f07f8fcb9bd450ee62758d67b62333d348c39 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c @@ -0,0 +1,233 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <string.h> +#include "nnacl_c/int8/resize_int8.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/errorcode.h" + +int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w, + int channel, int index, int count, ResizeQuantArg quant_arg) { + if (out_w == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int in_plane = in_h * in_w; + int out_plane = out_h * out_w; + for (int n = 0; n < batch; n++) { + const int8_t *in_b_ptr = input_ptr + n * in_plane * channel; + int8_t *out_b_ptr = output_ptr + n * out_plane * channel; + for (int t = 0; t < count; t++) { + int ori_out_h = (index + t) / out_w; + int ori_out_w = (index + t) % out_w; + int32_t x_lower_value = quant_arg.x_axis_lower_[ori_out_w]; + int32_t x_upper_value = quant_arg.x_axis_upper_[ori_out_w]; + int32_t y_lower_value = quant_arg.y_axis_lower_[ori_out_h]; + int32_t y_upper_value = quant_arg.y_axis_upper_[ori_out_h]; + int32_t weight_x = quant_arg.x_axis_index_[ori_out_w] - (1 << 10) * x_lower_value; + int32_t one_minus_weight_x = (1 << 10) - weight_x; + int32_t weight_y = quant_arg.y_axis_index_[ori_out_h] - (1 << 10) * y_lower_value; + int32_t one_minus_weight_y = (1 << 10) - weight_y; + int64_t left_bottom_coef = (int64_t)(one_minus_weight_x * one_minus_weight_y); + int64_t left_top_coef = (int64_t)(weight_y * one_minus_weight_x); + int64_t right_bottom_coef = (int64_t)(weight_x * one_minus_weight_y); + int64_t right_top_coef = (int64_t)(weight_x * weight_y); + int input_lb_index = (y_lower_value * in_w + x_lower_value) * channel; + int input_lt_index = (y_upper_value * in_w + x_lower_value) * channel; + int input_rb_index = (y_lower_value * in_w + x_upper_value) * channel; + int input_rt_index = (y_upper_value * in_w + x_upper_value) * channel; + int c = 0; + for (; c < channel; c++) { + int64_t out_left_bottom = left_bottom_coef * in_b_ptr[input_lb_index]; + int64_t out_left_top = left_top_coef * in_b_ptr[input_lt_index]; + int64_t out_right_bottom = right_bottom_coef * in_b_ptr[input_rb_index]; + int64_t out_right_top = right_top_coef * in_b_ptr[input_rt_index]; + int64_t out_value =
out_left_bottom + out_left_top + out_right_bottom + out_right_top; + out_b_ptr[0] = (int8_t)((out_value + (1 << 19)) / (1 << 20)); + input_lb_index++; + input_lt_index++; + input_rb_index++; + input_rt_index++; + out_b_ptr++; + } + } + } + return NNACL_OK; +} + +int ResizeBilinearWithFloatScaleInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, + int out_h, int out_w, int channel, int index, int count, + ResizeFloatScaleQuantArg quant_arg) { + if (out_w == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int in_plane = in_h * in_w; + int out_plane = out_h * out_w; + for (int n = 0; n < batch; n++) { + const int8_t *in_b_ptr = input_ptr + n * in_plane * channel; + int8_t *out_b_ptr = output_ptr + n * out_plane * channel; + for (int t = 0; t < count; t++) { + int ori_out_h = (index + t) / out_w; + int ori_out_w = (index + t) % out_w; + int32_t x_lower_value = quant_arg.x_axis_lower_[ori_out_w]; + int32_t x_upper_value = quant_arg.x_axis_upper_[ori_out_w]; + int32_t y_lower_value = quant_arg.y_axis_lower_[ori_out_h]; + int32_t y_upper_value = quant_arg.y_axis_upper_[ori_out_h]; + float weight_x = quant_arg.x_axis_index_[ori_out_w] - x_lower_value; + const float one_minus_weight_x = 1 - weight_x; + float weight_y = quant_arg.y_axis_index_[ori_out_h] - y_lower_value; + const float one_minus_weight_y = 1 - weight_y; + float left_bottom_coef = one_minus_weight_x * one_minus_weight_y; + float left_top_coef = weight_y * one_minus_weight_x; + float right_bottom_coef = weight_x * one_minus_weight_y; + float right_top_coef = weight_x * weight_y; + int input_lb_index = (y_lower_value * in_w + x_lower_value) * channel; + int input_lt_index = (y_upper_value * in_w + x_lower_value) * channel; + int input_rb_index = (y_lower_value * in_w + x_upper_value) * channel; + int input_rt_index = (y_upper_value * in_w + x_upper_value) * channel; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= channel - 4; c += 4) { + float32x4_t in_lb; + in_lb[0] = (float)in_b_ptr[input_lb_index]; + in_lb[1] = (float)in_b_ptr[input_lb_index + 1]; + in_lb[2] = (float)in_b_ptr[input_lb_index + 2]; + in_lb[3] = (float)in_b_ptr[input_lb_index + 3]; + float32x4_t out_left_bottom = vmulq_n_f32(in_lb, left_bottom_coef); + float32x4_t in_lt; + in_lt[0] = (float)in_b_ptr[input_lt_index]; + in_lt[1] = (float)in_b_ptr[input_lt_index + 1]; + in_lt[2] = (float)in_b_ptr[input_lt_index + 2]; + in_lt[3] = (float)in_b_ptr[input_lt_index + 3]; + float32x4_t out_left_top = vmulq_n_f32(in_lt, left_top_coef); + float32x4_t in_rb; + in_rb[0] = (float)in_b_ptr[input_rb_index]; + in_rb[1] = (float)in_b_ptr[input_rb_index + 1]; + in_rb[2] = (float)in_b_ptr[input_rb_index + 2]; + in_rb[3] = (float)in_b_ptr[input_rb_index + 3]; + float32x4_t out_right_bottom = vmulq_n_f32(in_rb, right_bottom_coef); + float32x4_t in_rt; + in_rt[0] = (float)in_b_ptr[input_rt_index]; + in_rt[1] = (float)in_b_ptr[input_rt_index + 1]; + in_rt[2] = (float)in_b_ptr[input_rt_index + 2]; + in_rt[3] = (float)in_b_ptr[input_rt_index + 3]; + float32x4_t out_right_top = vmulq_n_f32(in_rt, right_top_coef); + float32x4_t out_value1 = vaddq_f32(out_left_bottom, out_left_top); + float32x4_t out_value2 = vaddq_f32(out_right_top, out_right_bottom); + float32x4_t out_value = vaddq_f32(out_value1, out_value2); + out_b_ptr[0] = (int8_t)(out_value[0]); + out_b_ptr[1] = (int8_t)(out_value[1]); + out_b_ptr[2] = (int8_t)(out_value[2]); + out_b_ptr[3] = (int8_t)(out_value[3]); + input_lb_index += 4; + input_lt_index += 4; + input_rb_index += 4; + input_rt_index += 4; + 
out_b_ptr += 4; + } +#endif + for (; c < channel; c++) { + float out_left_bottom = left_bottom_coef * in_b_ptr[input_lb_index]; + float out_left_top = left_top_coef * in_b_ptr[input_lt_index]; + float out_right_bottom = right_bottom_coef * in_b_ptr[input_rb_index]; + float out_right_top = right_top_coef * in_b_ptr[input_rt_index]; + float out_value = out_left_bottom + out_left_top + out_right_bottom + out_right_top; + out_b_ptr[0] = (int8_t)(out_value); + input_lb_index++; + input_lt_index++; + input_rb_index++; + input_rt_index++; + out_b_ptr++; + } + } + } + return NNACL_OK; +} + +int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, int tid, int thread_num) { + int batch, y, x, c; + c = output_shape[3]; + int in_h, in_w, new_height, new_width; + in_h = input_shape[1]; + in_w = input_shape[2]; + new_height = output_shape[1]; + new_width = output_shape[2]; + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int input_y = 0; + ComputeNearestNeighborInt(y, in_h, new_height, align_corners, &input_y); + for (x = 0; x < output_shape[2]; x++) { + int input_x = 0; + ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x); + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(int8_t)); + } + } + } + + return NNACL_OK; +} + +void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners, + int32_t *nearest) { + if (new_size == 0) { + return; + } + *nearest = (in_size * pos) / new_size; + if (align_corners && new_size != 1) { + *nearest = ((in_size - 1) * pos + (new_size - 1) / 2) / (new_size - 1); + } + *nearest = *nearest < in_size ? *nearest : in_size - 1; +} + +int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, const QuantMulArg *multiplier, + const QuantArg *quant_in, const QuantArg *quant_out, int tid, int thread_num) { + const int base_offset = 20; + int32_t batch, y, x, c; + int32_t in_h, in_w, new_height, new_width; + in_h = input_shape[1]; + in_w = input_shape[2]; + new_height = output_shape[1]; + new_width = output_shape[2]; + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int input_y = 0; + ComputeNearestNeighborInt(y, in_h, new_height, align_corners, &input_y); + for (x = 0; x < output_shape[2]; x++) { + int input_x = 0; + ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x); + for (c = 0; c < output_shape[3]; c++) { + int in_offset = Offset(input_shape, batch, input_y, input_x, c); + int out_offset = Offset(output_shape, batch, y, x, c); + + int32_t out_value = MultiplyByQuantizedMultiplier( + input_data[in_offset] - quant_in->zp_, multiplier->multiplier_, + multiplier->left_shift_ + base_offset, multiplier->right_shift_ - base_offset) + + quant_out->zp_; + out_value = out_value > INT8_MAX ? INT8_MAX : out_value; + out_value = out_value < INT8_MIN ? 
INT8_MIN : out_value; + output_data[out_offset] = (int8_t)out_value; + } + } + } + } + + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..78e1972bf15c08109b57e0d66e9cb0613b9a6875 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_RESIZE_H_ +#define NNACL_INT8_RESIZE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include <stdint.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/resize_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w, + int channel, int index, int count, ResizeQuantArg quant_arg); + +int ResizeBilinearWithFloatScaleInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, + int out_h, int out_w, int channel, int index, int count, + ResizeFloatScaleQuantArg quant_arg); + +int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, int tid, int thread_num); + +int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, const QuantMulArg *multiplier, + const QuantArg *quant_in, const QuantArg *quant_out, int tid, int thread_num); + +void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners, + int32_t *nearest); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_RESIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..23a508c52004ccef708bfeb7ef2d3aa6b5c740fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c @@ -0,0 +1,164 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/int8/scale_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +#ifdef ENABLE_NEON +int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, const ScaleQuantParameter *scale_param) { + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + scale_param->scale_mul_arg_.right_shift_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +int16x4_t ClacSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2, + const ScaleQuantParameter *scale_param) { + int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_); + int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)(scale_param->scale_mul_arg_.left_shift_)); + int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << (size_t)(scale_param->offset_mul_arg_.left_shift_)); + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + scale_param->scale_mul_arg_.right_shift_); + int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2), + scale_param->offset_mul_arg_.right_shift_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_)); + raw_sum = vaddq_s32(raw_sum, raw_sum2); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_)); + return vqmovn_s32(raw_sum); +} +#endif + +void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param, + int real_dst_count) { + int index = 0; +#ifdef ENABLE_NEON + int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_); + + for (; index <= real_dst_count - 8; index += 8) { + int8x8_t input_s8 = vld1_s8(in_data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_)); + + int8x8_t input1_s8 = vld1_s8(scale + index); + int16x8_t input1_s16 = vmovl_s8(input1_s8); + int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_)); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int16x4_t sum_low = + ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param); + int16x4_t sum_high = + ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = 
vqmovn_s16(res_s16); + vst1_s8(out_data, res_u8_n0); + out_data += 8; + } +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = scale_param->input_zp_ + in_data[index]; + const int32_t input1_val = scale_param->scale_zp_ + scale[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_), + scale_param->scale_mul_arg_.multiplier_), + scale_param->scale_mul_arg_.right_shift_); + + mul_result += scale_param->output_zp_; + + if (mul_result > scale_param->output_activation_max_) { + out_data[index] = scale_param->output_activation_max_; + } else if (mul_result < scale_param->output_activation_min_) { + out_data[index] = scale_param->output_activation_min_; + } else { + out_data[index] = (int8_t)mul_result; + } + } + return; +} + +void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset, + const ScaleQuantParameter *scale_param, int real_dst_count) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= real_dst_count - 8; index += 8) { + int8x8_t input_s8 = vld1_s8(in_data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_)); + + int8x8_t input1_s8 = vld1_s8(scale + index); + int16x8_t input1_s16 = vmovl_s8(input1_s8); + int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_)); + + int8x8_t input2_s8 = vld1_s8(offset + index); + int16x8_t input2_s16 = vmovl_s8(input2_s8); + int16x8_t input2_val = vaddq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_)); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val)); + int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val)); + + int16x4_t sum_low = ClacSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param); + int16x4_t sum_high = ClacSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(out_data, res_u8_n0); + out_data += 8; + } +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = in_data[index] - scale_param->input_zp_; + const int32_t input1_val = scale[index] - scale_param->scale_zp_; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_), + scale_param->scale_mul_arg_.multiplier_), + scale_param->scale_mul_arg_.right_shift_); + int tmp_bias = offset[index] - scale_param->offset_zp_; + int bias = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_), + scale_param->offset_mul_arg_.multiplier_), + scale_param->offset_mul_arg_.right_shift_); + + mul_result += bias + scale_param->output_zp_; + + if (mul_result > scale_param->output_activation_max_) { + out_data[index] = scale_param->output_activation_max_; + } else if (mul_result < scale_param->output_activation_min_) { + out_data[index] = scale_param->output_activation_min_; + } else { + out_data[index] = (int8_t)mul_result; + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..02cace65300fd29523780915cd91f3dd958ffa07 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SCALE_INT8_H_ +#define NNACL_SCALE_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param, + int real_dst_count); +void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset, + const ScaleQuantParameter *scale_param, int real_dst_count); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_SCALE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..6f8926021a11776a0b4094e275396e3209705b15 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/sigmoid_int8.h" + +int SigmoidInt8(const int8_t *src, int length, int8_t *dst, int8_t *table) { + for (int i = 0; i < length; i++) { + const int8_t input_value = src[i]; + uint8_t index = (uint8_t)input_value; + dst[i] = table[index]; + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..5fc8db7ba4b2d88024cd16072dc3298da16798c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_SIGMOID_INT8_H_
+#define NNACL_INT8_SIGMOID_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SigmoidInt8(const int8_t *src, int length, int8_t *dst, int8_t *table);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SIGMOID_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..5ae5a4f5f16e29535005d8a30f8ba12c3917db68
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
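+
+// Requantization scheme used below (as implemented here, not an external
+// contract): when input and output quant params differ, each element becomes
+//   out = clamp(MultiplyByQuantizedMultiplier(in - in_zp, multiplier,
+//                                             left_shift + base_offset,
+//                                             right_shift - base_offset) + out_zp)
+// with base_offset = 20 giving the fixed-point multiply extra headroom;
+// when the params match, the inner copy degenerates to memcpy.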
+#include "nnacl_c/int8/slice_int8.h"
+#include <float.h>
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+int SliceInt8(const int8_t *input, int8_t *output, const SliceStruct *param, const SliceQuantArg *quant_arg,
+              int thread_id, int thread_num) {
+  double input_scale = quant_arg->in_args_.scale_;
+  int input_zp = quant_arg->in_args_.zp_;
+  double output_scale = quant_arg->out_args_.scale_;
+  int output_zp = quant_arg->out_args_.zp_;
+  const int base_offset = 20;
+  int act_min = quant_arg->output_activation_min_;
+  int act_max = quant_arg->output_activation_max_;
+
+  size_t out_stride[8];
+  out_stride[7] = 1;
+  for (int i = 6; i >= 0; --i) {
+    out_stride[i] = out_stride[i + 1] * param->size_[i + 1];
+  }
+
+  int count_per_thread = UP_DIV(param->size_[5], thread_num);
+  size_t thread_begin = thread_id * count_per_thread;
+  size_t thread_end = MSMIN(param->size_[5], thread_begin + count_per_thread);
+  int unit_size = param->size_[7] * sizeof(int8_t);
+  size_t in_stride[8];
+  in_stride[7] = 1;
+  for (int i = 6; i >= 0; --i) {
+    in_stride[i] = param->shape_[i + 1] * in_stride[i + 1];
+  }
+  int i, j, k, l, n, h, w, c;
+
+  int equal_quant = 0;
+  if (fabs(input_scale - output_scale) <= FLT_EPSILON && input_zp == output_zp) {
+    equal_quant = 1;
+  }
+
+  for (i = 0; i < param->size_[0]; ++i) {
+    size_t out_offset0 = i * out_stride[0];
+    size_t in_offset0 = (i + param->begin_[0]) * in_stride[0] + param->begin_[7];
+    for (j = 0; j < param->size_[1]; ++j) {
+      size_t out_offset1 = j * out_stride[1] + out_offset0;
+      size_t in_offset1 = (j + param->begin_[1]) * in_stride[1] + in_offset0;
+      for (k = 0; k < param->size_[2]; ++k) {
+        size_t out_offset2 = k * out_stride[2] + out_offset1;
+        size_t in_offset2 = (k + param->begin_[2]) * in_stride[2] + in_offset1;
+        for (l = 0; l < param->size_[3]; ++l) {
+          size_t out_offset3 = l * out_stride[3] + out_offset2;
+          size_t in_offset3 = (l + param->begin_[3]) * in_stride[3] + in_offset2;
+          for (n = 0; n < param->size_[4]; ++n) {
+            size_t out_offset4 = n * out_stride[4] + out_offset3;
+            size_t in_offset4 = (n + param->begin_[4]) * in_stride[4] + in_offset3;
+            for (h = thread_begin; h < thread_end; ++h) {
+              size_t out_offset5 = h * out_stride[5] + out_offset4;
+              size_t in_offset5 = (h + param->begin_[5]) * in_stride[5] + in_offset4;
+              for (w = 0; w < param->size_[6]; ++w) {
+                size_t out_offset = w * out_stride[6] + out_offset5;
+                size_t in_offset = (w + param->begin_[6]) * in_stride[6] + in_offset5;
+                if (equal_quant == 1) {
+                  memcpy(output + out_offset, input + in_offset, unit_size);
+                } else {
+                  for (c = 0; c < param->size_[7]; ++c) {
+                    int32_t output_val =
+                      MultiplyByQuantizedMultiplier(input[in_offset + c] - input_zp, quant_arg->multiplier_.multiplier_,
+                                                    quant_arg->multiplier_.left_shift_ + base_offset,
+                                                    quant_arg->multiplier_.right_shift_ - base_offset) +
+                      output_zp;
+                    output_val = MSMAX(INT8_MIN, MSMIN(output_val, INT8_MAX));
+                    output[c + out_offset] = (int8_t)MSMAX(act_min, MSMIN(output_val, act_max));
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return 0;
+}
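+
+// Threading note: work is partitioned along dimension 5 only; each worker
+// handles UP_DIV(param->size_[5], thread_num) rows (thread_begin/thread_end
+// above), so the kernel is invoked once per thread_id.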
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..7ace3db0ac931cc70db8d915c3a2d369c39ae4f8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_SLICE_INT8_H_
+#define NNACL_INT8_SLICE_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/kernel/slice.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SliceInt8(const int8_t *input, int8_t *output, const SliceStruct *param, const SliceQuantArg *quant_arg,
+              int thread_id, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SLICE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..7d32f06a385e109ccfe9b4347d9d4c713e6b0633
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/int8/softmax_int8.h" +#include "nnacl_c/errorcode.h" + +int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int32_t *exp_data, int32_t *sum_data, + const int32_t *input_shape, int n_dim, int32_t axis, const SoftmaxQuantArg *quant_param) { + int axis_shape_size = input_shape[axis]; + int inner_size = 1; + if (n_dim > DIMENSION_5D) { + return NNACL_ERR; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int o = 0; o < count; o++) { + int outter_offset = o * axis_shape_size * inner_size; + + for (int c = 0; c < inner_size; c++) { + int8_t max_row = quant_param->output_activation_min_; + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + c + i * inner_size; + max_row = MSMAX(max_row, input_ptr[axis_offset]); + } + + int32_t exp_sum = 0; + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + c + i * inner_size; + const int32_t input_val = input_ptr[axis_offset] - max_row; + const int32_t input_scaled = SaturatingRoundingDoublingHighMul( + input_val * (1 << (unsigned int)quant_param->shift_left_), quant_param->output_multiplier_); + int exp_val = exp_on_negative_values(input_scaled, 5); + exp_data[axis_offset] = exp_val; + exp_sum = exp_sum + Rescale(exp_val, 0, 12); + } + sum_data[c] = exp_sum; + } + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + i * inner_size; + for (int c = 0; c < inner_size; ++c) { + int num_bits_over_unit; + int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit); + int unsat_output = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8); + + int raw_output = unsat_output + quant_param->output_activation_min_; + output_ptr[axis_offset + c] = + (int8_t)MSMAX(quant_param->output_activation_min_, MSMIN(raw_output, quant_param->output_activation_max_)); + } + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..6f538af2309e44bd9024620d5b73bd5897e41dcb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_SOFTMAX_INT8_H_
+#define NNACL_INT8_SOFTMAX_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int32_t *exp_data, int32_t *sum_data,
+                const int32_t *input_shape, int n_dim, int32_t axis, const SoftmaxQuantArg *quant_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SOFTMAX_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..f5bba63b0719458b65289f30c71f4a6cc8887c94
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/int8/space_to_batch_int8.h"
+#include "nnacl_c/common_func.h"
+
+void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int32_t *block_sizes, const int32_t *in_shape,
+                            const int32_t *out_shape) {
+  int out_dim0 = out_shape[0];
+  int out_dim1 = out_shape[1];
+  int out_dim2 = out_shape[2];
+  int copy_num = out_shape[3];
+  int block_w = block_sizes[1];
+  int block_h = block_sizes[0];
+  int in_strides[4] = {0};
+  ComputeStrides(in_shape, in_strides, 4);
+  int out_strides[4] = {0};
+  ComputeStrides(out_shape, out_strides, 4);
+  size_t copy_size = copy_num * sizeof(int8_t);
+  size_t out_offset = 0;
+
+  NNACL_CHECK_ZERO_RETURN(in_shape[0]);
+  NNACL_CHECK_ZERO_RETURN(block_w);
+  for (int n = 0; n < out_dim0; ++n) {
+    int in_n = n % in_shape[0];
+    int32_t stride_w = (n / in_shape[0]) % block_w;
+    int32_t stride_h = (n / in_shape[0]) / block_w;
+    size_t in_offset0 = in_n * in_strides[0];
+    for (int h = 0; h < out_dim1; ++h) {
+      size_t in_offset1 = in_offset0 + (h * block_h + stride_h) * in_strides[1];
+      for (int w = 0; w < out_dim2; ++w) {
+        size_t in_offset2 = in_offset1 + (w * block_w + stride_w) * in_strides[2];
+        memcpy(output + out_offset, input + in_offset2, copy_size);
+        out_offset += copy_num;
+      }
+    }
+  }
+}
+
+void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp) {
+  int block_shape_h = param->block_sizes_[0];
+  int block_shape_w = param->m_ == 2 ? param->block_sizes_[1] : 1;
+  int in_b = param->input_shape_[0];
+  int in_h = param->input_shape_[1];
+  int in_w = param->input_shape_[2];
+  int channel = param->input_shape_[3];
+  int out_h = param->output_shape_[1];
+  int out_w = param->output_shape_[2];
+  int pad_t = param->paddings_[0];
+  int pad_l = param->m_ == 2 ?
param->paddings_[2] : 0; + + NNACL_CHECK_ZERO_RETURN(in_b); + NNACL_CHECK_ZERO_RETURN(block_shape_w); + for (int i = 0; i < param->output_shape_[0]; ++i) { + int in_batch = i % in_b; + int offset_w = (i / in_b) % block_shape_w; + int offset_h = (i / in_b) / block_shape_w; + int in_b_offset = in_batch * in_h * in_w * channel; + int out_b_offset = i * out_h * out_w * channel; + for (int j = 0; j < out_h; ++j) { + int out_h_offset = out_b_offset + j * out_w * channel; + for (int k = 0; k < out_w; ++k) { + int8_t *out_ptr = output + out_h_offset + k * channel; + int index_h = j * block_shape_h + offset_h; + int index_w = k * block_shape_w + offset_w; + if (index_h < pad_t || index_h >= (pad_t + in_h) || index_w < pad_l || index_w >= (pad_l + in_w)) { + memset(out_ptr, zp, channel * sizeof(int8_t)); + } else { + int in_plane_offset = in_b_offset + ((index_h - pad_t) * in_w + (index_w - pad_l)) * channel; + const int8_t *in_ptr = input + in_plane_offset; + memcpy(out_ptr, in_ptr, channel * sizeof(int8_t)); + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..8d60bc1c12e57d9f5240e4bbf905cea0e17f0567 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_SPACE_TO_BATCH_INT8_H_ +#define NNACL_INT8_SPACE_TO_BATCH_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int32_t *block_sizes, const int32_t *in_shape, + const int32_t *out_shape); +void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SPACE_TO_BATCH_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..5cc53dfadf033c52f48ed4de3dd977ec3e1cd742 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/split_int8.h"
+#include <float.h>
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/split_parameter.h"
+#include "nnacl_c/errorcode.h"
+
+int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit,
+                const SplitParameter *param) {
+  if (in_data == NULL || out_data == NULL) {
+    return NNACL_ERR;
+  }
+  const int num_split = param->num_split_;
+  const int32_t *split_sizes = param->split_sizes_;
+  const int32_t *strides = param->strides_;
+  const int split_dim = param->split_dim_;
+  int in_stride = strides[split_dim];
+
+  int stride_per_split = in_stride * input_shape[split_dim];
+  int split_which = offset % num_split;
+  int split_times = offset / num_split;
+  const int8_t *src = in_data + split_times * stride_per_split;
+  for (int i = 0; i < split_which; i++) {
+    src += split_sizes[i] * in_stride;
+  }
+
+  const QuantArg in_quant_arg = param->quant_arg_.in_args_;
+  float in_scale = in_quant_arg.scale_;
+  int32_t in_zp = in_quant_arg.zp_;
+  const QuantArg *out_quant_arg = param->quant_arg_.out_args_;
+
+  for (int i = offset; i < offset + num_unit; i++) {
+    split_which = i % num_split;
+    split_times = i / num_split;
+    int copy_size = split_sizes[split_which] * in_stride;
+    int8_t *dst = out_data[split_which] + split_times * copy_size;
+    float out_scale = out_quant_arg[split_which].scale_;
+    int32_t out_zp = out_quant_arg[split_which].zp_;
+    if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) {
+      (void)memcpy(dst, src, copy_size * sizeof(int8_t));
+    } else {
+      float scale = in_scale / out_scale;
+      float bias = -in_zp * scale;
+      for (int j = 0; j < copy_size; j++) {
+        int32_t output_tmp = round(src[j] * scale + bias) + out_zp;
+        if (output_tmp > param->quant_arg_.output_activation_max_) {
+          dst[j] = param->quant_arg_.output_activation_max_;
+        } else if (output_tmp < param->quant_arg_.output_activation_min_) {
+          dst[j] = param->quant_arg_.output_activation_min_;
+        } else {
+          dst[j] = (int8_t)output_tmp;
+        }
+      }
+    }
+    src += copy_size;
+  }
+
+  return NNACL_OK;
+}
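+
+// Requantization note: outputs whose quant params differ from the input are
+// rescaled per element as dst = round(src * in_scale / out_scale + bias) + out_zp
+// with bias = -in_zp * (in_scale / out_scale), clamped to
+// [output_activation_min_, output_activation_max_]; matching params take the
+// memcpy fast path.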
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..8db63febb5f2275bd81e88df5d6fa6f53a868d7f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_SPLIT_INT8_H_
+#define NNACL_INT8_SPLIT_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/split_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit,
+                const SplitParameter *split_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SPLIT_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..19b4b914c3c6f5fe80305a4e26fab35a9f6f01e3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/squeeze_int8.h"
+#include <math.h>
+
+void SqueezeInt8(const int8_t *input_ptr, int8_t *output_ptr, const SqueezeQuantArg *quant_Squeeze_parm, int num,
+                 int task_id, int thread_count) {
+  float output_scale = quant_Squeeze_parm->out_quant_args_->scale_;
+  const float output_inverse_scale = 1.f / output_scale;
+  QuantArg *input_quant = quant_Squeeze_parm->in_quant_args_;
+  int output_zp = quant_Squeeze_parm->out_quant_args_->zp_;
+
+  const int i = 0;
+  for (int j = task_id; j < num; j += thread_count) {
+    float scale = input_quant[i].scale_ * output_inverse_scale;
+    float bias = -input_quant[i].zp_ * scale;
+    int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp;
+    if (output_tmp > 127) {
+      output_ptr[j] = 127;
+    } else if (output_tmp < -128) {
+      output_ptr[j] = -128;
+    } else {
+      output_ptr[j] = (int8_t)output_tmp;
+    }
+  }
+}
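+
+// Note: SqueezeInt8 is a plain dequantize-requantize copy; unlike the kernels
+// with configurable activation bounds, the result is clamped to the full int8
+// range [-128, 127], and elements are strided across workers by thread_count.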
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..4129a292e744320f8cbe5cb03c960b4dae0b9982
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_SQUEEZE_INT8_H_
+#define NNACL_INT8_SQUEEZE_INT8_H_
+
+#include "nnacl_c/squeeze_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void SqueezeInt8(const int8_t *input_ptr, int8_t *output_ptr, const SqueezeQuantArg *quant_Squeeze_parm, int num,
+                 int task_id, int thread_count);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SQUEEZE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..f1548537903cf79c8f323954a86ca2054d429050
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/sub_int8.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#include "nnacl_c/int8/common_func_int8.h"
+#endif
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef ENABLE_NEON
+
+int16x4_t DoClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
+                            int32x4_t output_multiplier_vec, const SubQuantArg *para) {
+  int32x4_t raw_data = vsubq_s32(scaled_input0, scaled_input1);
+
+  raw_data = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_data, left_shift_out_vec), output_multiplier_vec),
+                                        para->right_shift_out_);
+  raw_data = vaddq_s32(raw_data, vdupq_n_s32(para->out_args_.zp_));
+  raw_data = vmaxq_s32(raw_data, vdupq_n_s32(para->output_activation_min_));
+  raw_data = vminq_s32(raw_data, vdupq_n_s32(para->output_activation_max_));
+  return vqmovn_s32(raw_data);
+}
+
+void SubInt8NEON(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
+                 const SubQuantArg *para, int32_t *index) {
+  int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_);
+  int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_);
+  int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_);
+  int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_);
+  int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32((1 << (size_t)para->left_shift_out_));
+  int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_);
+  int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_);
+
+  for (; (*index) <= real_dst_count - 8; (*index) += 8) {
+    int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->in0_args_.zp_);
+    int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->in1_args_.zp_);
+
+    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
+    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
+    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
+    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
+
+    int32x4_t scaled_input0_low =
+      ClacScaledInput(input0_low, left_shift_result0_vec,
input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input0_high = + ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input1_low = + ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + int32x4_t scaled_input1_high = + ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + + int16x4_t sum_low = + DoClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + DoClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data + *index, res_u8_n0); + } +} +#endif + +int SubInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const SubQuantArg *para) { + int index = 0; +#ifdef ENABLE_NEON + SubInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + const int32_t shifted_input0_val = input0_val * para->left_shift_result0_; + const int32_t shifted_input1_val = input1_val * para->left_shift_result1_; + const int32_t scaled_input0_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_); + const int32_t scaled_input1_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_); + + const int32_t raw_data = scaled_input0_val - scaled_input1_val; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data * (1 << (unsigned int)para->left_shift_out_), + para->output_multiplier_), + para->right_shift_out_) + + para->out_args_.zp_; + + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e9107daa4c534810fcab69719e8f757289f9b210 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_SUB_INT8_H_
+#define NNACL_INT8_SUB_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SubInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
+            const SubQuantArg *para);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SUB_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..351718477072a573caafa97eb249b8f78adf7955
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/tanh_int8.h"
+#include <math.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant) {
+  for (int i = 0; i < size; ++i) {
+    float fp32_src = (input_ptr[i] - quant->in_zp_) * quant->in_scale_;
+    float fp32_dst = TanhOpt(fp32_src);
+    int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_);
+    output_ptr[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128);
+  }
+  return;
+}
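+
+// Note: the int8 tanh takes the float route: dequantize with in_scale_/in_zp_,
+// evaluate TanhOpt in fp32, then requantize with out_scale_/out_zp_ and clamp
+// to [-128, 127].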
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..332d71806a5d7c7ec5cc3c17e7cfe8b90fa80a4d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_TANH_INT8_H_
+#define NNACL_INT8_TANH_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+
+typedef struct TanhQuantParameter {
+  int32_t in_zp_;
+  int32_t out_zp_;
+  double in_scale_;
+  double out_scale_;
+} TanhQuantParameter;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_TANH_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..643ff0c1864b63235b9870669cefad61eb62c6fb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/topk_int8.h"
+#include <stdlib.h>
+
+int DescendCmpInt8(const void *a, const void *b) {
+  return ((const TopkNodeInt8 *)b)->element - ((const TopkNodeInt8 *)a)->element;
+}
+
+int AscendCmpInt8(const void *a, const void *b) {
+  return ((const TopkNodeInt8 *)a)->element - ((const TopkNodeInt8 *)b)->element;
+}
+
+void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter) {
+  int dim_size = parameter->dim_size_;
+  int outer_loop_num = parameter->outer_loop_num_;
+  int inner_loop_num = parameter->inner_loop_num_;
+  int k = parameter->k_;
+  // use the int8 node type so the qsort element size matches the comparators above
+  TopkNodeInt8 *top_map = (TopkNodeInt8 *)parameter->topk_node_list_;
+
+  int8_t *cur_input_data = (int8_t *)input_data;
+  int8_t *cur_output_data = (int8_t *)output_data;
+  int32_t *cur_output_index = output_index;
+  for (int i = 0; i < outer_loop_num; i++) {
+    int in_offset = i * dim_size * inner_loop_num;
+    int out_offset = i * k * inner_loop_num;
+    for (int j = 0; j < inner_loop_num; j++) {
+      for (int m = 0; m < dim_size; m++) {
+        int offset = in_offset + m * inner_loop_num + j;
+        top_map[m].element = *(cur_input_data + offset);
+        top_map[m].index = m;
+      }
+      qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmpInt8);
+      if (!parameter->sorted_) {
+        qsort(top_map, k, sizeof(top_map[0]), AscendCmpInt8);
+      }
+      for (int m = 0; m < k; m++) {
+        int offset = out_offset + m * inner_loop_num + j;
+        cur_output_data[offset] = top_map[m].element;
+        cur_output_index[offset] = top_map[m].index;
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..9910f3559025066ce1062cfedcc1336c82ae4db1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_TOPK_INT8_H_ +#define NNACL_INT8_TOPK_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/topk_fp32.h" + +typedef struct TopkNodeInt8 { + int8_t element; + int32_t index; +} TopkNodeInt8; + +#ifdef __cplusplus +extern "C" { +#endif +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_TOPK_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..65c1a59032247885cd0bf37feb8b32b7fedc41b9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/transpose_int8.h" +void TransposeDim2Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } + return; +} + +void TransposeDim3Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int 
k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void TransposeDim6Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + for (int p = 0; p < output5; ++p) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + p] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + p * stride5]; + } + } + } + } + } + } +} + +int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape, + const TransposeParameter *transpose_param) { + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + NNACL_CHECK_NULL_RETURN_ERR(transpose_param); + + const int32_t *perm = transpose_param->perm_; + const int32_t *strides = transpose_param->strides_; + const int32_t *out_strides = transpose_param->out_strides_; + const int num_axes = transpose_param->num_axes_; + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; i++) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, transpose_param->data_num_ * sizeof(int8_t)); + return NNACL_OK; + } + + switch (num_axes) { + case 2: + TransposeDim2Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 3: + TransposeDim3Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 4: + TransposeDim4Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 5: + TransposeDim5Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 6: + TransposeDim6Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + default: + return NNACL_ERR; + } + + return NNACL_OK; +} + +void TransposeDimsInt8(const int8_t *in_data, int8_t 
*out_data, const int32_t *output_shape,
+                       const TransposeParameter *transpose_param, int task_id, int thread_num) {
+  NNACL_CHECK_NULL_RETURN_VOID(in_data);
+  NNACL_CHECK_NULL_RETURN_VOID(out_data);
+  NNACL_CHECK_NULL_RETURN_VOID(output_shape);
+  NNACL_CHECK_NULL_RETURN_VOID(transpose_param);
+  NNACL_CHECK_ZERO_RETURN(thread_num);
+  const int32_t *perm = transpose_param->perm_;
+  const int32_t *strides = transpose_param->strides_;
+  const int32_t *out_strides = transpose_param->out_strides_;
+  int num_axes = transpose_param->num_axes_;
+  size_t data_size = (size_t)((*out_strides) * output_shape[0]);
+  size_t offset_size = UP_DIV(data_size, thread_num);
+  size_t task_offset = offset_size * task_id;
+  // check the task range before computing the count to avoid size_t underflow
+  if (data_size < task_offset) {
+    return;
+  }
+  size_t count = MSMIN(offset_size, data_size - task_offset);
+  for (size_t idx = task_offset; idx < task_offset + count; ++idx) {
+    int pos = (int)idx;
+    int output_idx = 0;
+    int input_idx = 0;
+    for (int i = 0; i < num_axes; ++i) {
+      NNACL_CHECK_ZERO_RETURN(*(out_strides + i));
+      int position = pos / *(out_strides + i);
+      int out_stride = i < num_axes - 1 ? out_strides[i] : 1;
+      output_idx += (position * out_stride);
+      input_idx += (position * strides[perm[i]]);
+      pos -= position * (*(out_strides + i));
+    }
+    out_data[output_idx] = in_data[input_idx];
+  }
+}
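+
+// Note: TransposeDimsInt8 walks the flattened output range assigned to this
+// task, decomposing each linear index digit-by-digit with out_strides while
+// accumulating the matching input offset through strides[perm[i]]; it serves
+// as the generic path for ranks the fixed DimN kernels above do not cover.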
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..50fbafee798be631e85fdda96f1acdd5b70f6c16
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_TRANSPOSE_INT8_H_
+#define NNACL_INT8_TRANSPOSE_INT8_H_
+
+#include <string.h>
+#include "nnacl_c/transpose_parameter.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape,
+                    const TransposeParameter *transpose_param);
+void TransposeDimsInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape,
+                       const TransposeParameter *transpose_param, int task_id, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_TRANSPOSE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c
new file mode 100644
index 0000000000000000000000000000000000000000..44d10ed9e7e6214eae2f24f0b0d260169775dc73
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/unsqueeze_int8.h"
+#include "nnacl_c/unsqueeze_parameter.h"
+#include "nnacl_c/errorcode.h"
+
+int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, const UnSqueezeParameter *para_, size_t data_size,
+                  int task_id) {
+  float output_scale = para_->quant_arg.out_quant_args_.scale_;
+  NNACL_CHECK_ZERO_RETURN_ERR(output_scale);
+  int8_t output_zp = para_->quant_arg.out_quant_args_.zp_;
+  float input_scale = para_->quant_arg.in_quant_args_.scale_;
+  int8_t input_zp = para_->quant_arg.in_quant_args_.zp_;
+
+  for (int i = task_id; i < (int)data_size; i += para_->thread_count_) {
+    output_ptr[i] = output_zp + round(1 / output_scale * input_scale * (input_ptr[i] - input_zp));
+  }
+  return 0;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h
new file mode 100644
index 0000000000000000000000000000000000000000..1649945d35429d4cae95e1e0d43765f74b92e8c2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_UNSQUEEZE_INT8_H_
+#define NNACL_INT8_UNSQUEEZE_INT8_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/unsqueeze_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, const UnSqueezeParameter *para_, size_t data_size,
+                  int task_id);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_UNSQUEEZE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c
new file mode 100644
index 0000000000000000000000000000000000000000..63b930c24080ac97bd9f8f6408c6a98e4112bded
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c
@@ -0,0 +1,188 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/op_base.h" + +void Deconv4X8AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 tmp = _mm256_set1_ps(*src); + __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM)); + weight += C8NUM; + __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM)); + __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM)); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res4 = _mm256_fmadd_ps(tmp1, w0, res4); + src += C4NUM; + res7 = _mm256_fmadd_ps(tmp2, w0, res7); + res10 = _mm256_fmadd_ps(tmp3, w0, res10); + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); +} + +void Deconv4X16AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + weight += C16NUM; + __m256 tmp = _mm256_set1_ps(*src); + __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM)); + __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM)); + __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM)); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + src += C4NUM; + res4 = _mm256_fmadd_ps(tmp1, w0, res4); + res5 = _mm256_fmadd_ps(tmp1, w1, res5); + res7 = _mm256_fmadd_ps(tmp2, w0, res7); + res8 = _mm256_fmadd_ps(tmp2, w1, res8); + res10 = _mm256_fmadd_ps(tmp3, w0, res10); + res11 = _mm256_fmadd_ps(tmp3, w1, res11); + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); +} + +void Deconv4X24AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res3 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res6 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res9 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + __m256 res12 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + __m256 w2 = _mm256_loadu_ps(weight + C16NUM); + __m256 tmp = _mm256_set1_ps(*src); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + res3 = _mm256_fmadd_ps(tmp, w2, res3); + tmp = _mm256_set1_ps(*(src + C1NUM)); + 
res4 = _mm256_fmadd_ps(tmp, w0, res4); + res5 = _mm256_fmadd_ps(tmp, w1, res5); + res6 = _mm256_fmadd_ps(tmp, w2, res6); + tmp = _mm256_set1_ps(*(src + C2NUM)); + res7 = _mm256_fmadd_ps(tmp, w0, res7); + res8 = _mm256_fmadd_ps(tmp, w1, res8); + res9 = _mm256_fmadd_ps(tmp, w2, res9); + tmp = _mm256_set1_ps(*(src + C3NUM)); + res10 = _mm256_fmadd_ps(tmp, w0, res10); + res11 = _mm256_fmadd_ps(tmp, w1, res11); + res12 = _mm256_fmadd_ps(tmp, w2, res12); + weight += C24NUM; + src += C4NUM; + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); + + _mm256_storeu_ps(dst + C2NUM * stride, res3); + _mm256_storeu_ps(dst + C2NUM * stride + C8NUM, res6); + _mm256_storeu_ps(dst + C2NUM * stride + C16NUM, res9); + _mm256_storeu_ps(dst + C2NUM * stride + C24NUM, res12); +} + +void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, const int plane) { + NNACL_CHECK_ZERO_RETURN(plane); + int col_num = 0; + int col_block = UP_DIV(col / plane, C8NUM); + DeconvAvxKernel kernel[3] = {Deconv4X8AvxKernel, Deconv4X16AvxKernel, Deconv4X24AvxKernel}; + for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) { + col_num = MSMIN(C3NUM, col_block - col_tmp); + for (int p = 0; p < plane; ++p) { + for (int r = 0; r < row; r += C4NUM) { + kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth, + c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth, + row * C8NUM * plane); + } + } + } +} + +#ifdef ENABLE_DEBUG +void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride) { + __m256 res[C12NUM]; + __m256 w[C3NUM]; + for (int i = 0; i < C12NUM; ++i) { + res[i] = _mm256_setzero_ps(); + } + for (int d = 0; d < depth; ++d) { + for (int c = 0; c < col; ++c) { + w[c] = _mm256_loadu_ps(weight + c * C8NUM); + } + weight += col * C8NUM; + for (int r = 0; r < row; ++r) { // C4NUM rows + __m256 tmp = _mm256_set1_ps(*src); + for (int c = 0; c < col; ++c) { // 3 * C8NUM + res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]); + } + src += 1; + } + } + // write + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + _mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]); + } + dst += stride; + } +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c new file mode 100644 index 0000000000000000000000000000000000000000..91d7d154a750a358916d6c9e9a1f66ed366cf9fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c @@ -0,0 +1,352 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/avx/common_utils.h" + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + stride /= sizeof(float); + int loop_c8 = 0; + size_t src_stride = plane_size * C8NUM; + for (; loop_c8 <= (int)(oc8div)-C32NUM; loop_c8 += C32NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + __m256 bias4 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias4 = _mm256_loadu_ps(bias + C24NUM); + bias += C32NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src13 = _mm256_loadu_ps(src + src_stride * C3NUM); + __m256 src14 = _mm256_loadu_ps(src + src_stride * C3NUM + C8NUM); + + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src13 = _mm256_add_ps(src13, bias4); + src14 = _mm256_add_ps(src14, bias4); + + ActBlock8Avx(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + _mm256_storeu_ps(dst_c8 + C16NUM, src9); + _mm256_storeu_ps(dst_c8 + C24NUM, src13); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + _mm256_storeu_ps(dst_c8 + C16NUM, src10); + _mm256_storeu_ps(dst_c8 + C24NUM, src14); + dst_c8 += stride; + + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + __m256 src15 = _mm256_loadu_ps(src + src_stride * C3NUM + C16NUM); + __m256 src16 = _mm256_loadu_ps(src + src_stride * C3NUM + C24NUM); + src += C32NUM; + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + src15 = _mm256_add_ps(src15, bias4); + src16 = _mm256_add_ps(src16, bias4); + + ActBlock8Avx(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + _mm256_storeu_ps(dst_c8 + C16NUM, src11); + _mm256_storeu_ps(dst_c8 + C24NUM, src15); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + _mm256_storeu_ps(dst_c8 + C16NUM, src12); + _mm256_storeu_ps(dst_c8 + C24NUM, src16); + dst_c8 += 
stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src4 = _mm256_loadu_ps(src + src_stride * C3NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + src4 = _mm256_add_ps(src4, bias4); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + _mm256_storeu_ps(dst_c8 + C16NUM, src3); + _mm256_storeu_ps(dst_c8 + C24NUM, src4); + dst_c8 += stride; + src += C8NUM; + } + src += C3NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C24NUM; loop_c8 += C24NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias += C24NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + + ActBlock12Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, + relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + _mm256_storeu_ps(dst_c8 + C16NUM, src9); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + _mm256_storeu_ps(dst_c8 + C16NUM, src10); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + _mm256_storeu_ps(dst_c8 + C16NUM, src11); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + _mm256_storeu_ps(dst_c8 + C16NUM, src12); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src3, relu_type == 1, relu_type == C3NUM); + + 
_mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + _mm256_storeu_ps(dst_c8 + C16NUM, src3); + dst_c8 += stride; + src += C8NUM; + } + src += C2NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C16NUM; loop_c8 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + + ActBlock8Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + + ActBlock2Avx(&src1, &src2, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + dst_c8 += stride; + src += C8NUM; + } + src += src_stride; + } + for (; loop_c8 < (int)(oc8div); loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + dst_c8 += stride; + src += C8NUM; + } + } + if (oc8mod == 0) { + return; + } + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += 
C8NUM; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += C8NUM, dst_c1 += stride) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + __m128 src_high = _mm256_extractf128_ps(src1, 1); + + switch (oc8mod) { + case 1: + dst_c1[0] = _mm256_cvtss_f32(src1); + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + dst_c1[C2NUM] = MS_F32X8_GETI(src1, C2NUM); + break; + case C4NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + break; + case C5NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_store_ss(dst_c1 + C4NUM, src_high); + break; + case C6NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + break; + case C7NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + dst_c1[C6NUM] = MS_F32X4_GETI(src_high, C2NUM); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7eb313acd76b9540bab72a38eb23542d85091ec0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c @@ -0,0 +1,274 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#include "nnacl_c/fp32/common_func_fp32.h" + +void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic8, size_t oc8) { + const float *src_tmp = src; + for (int i = 0; i < oc8; ++i) { + src = src_tmp; +#ifndef ENABLE_DEBUG + asm volatile( + "vxorps %%xmm0, %%xmm0, %%xmm0\n" + "vmovaps %%ymm0, %%ymm1\n" + "vmovaps %%ymm0, %%ymm2\n" + "vmovaps %%ymm0, %%ymm3\n" + "vmovaps %%ymm0, %%ymm4\n" + "vmovaps %%ymm0, %%ymm5\n" + "vmovaps %%ymm0, %%ymm6\n" + "vmovaps %%ymm0, %%ymm7\n" + : /* no input */ + : /* no input */ + : "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7"); +#else + register __m256 dst1 asm("ymm0") = _mm256_setzero_ps(); + register __m256 dst2 asm("ymm1") = _mm256_setzero_ps(); + register __m256 dst3 asm("ymm2") = _mm256_setzero_ps(); + register __m256 dst4 asm("ymm3") = _mm256_setzero_ps(); + register __m256 dst5 asm("ymm4") = _mm256_setzero_ps(); + register __m256 dst6 asm("ymm5") = _mm256_setzero_ps(); + register __m256 dst7 asm("ymm6") = _mm256_setzero_ps(); + register __m256 dst8 asm("ymm7") = _mm256_setzero_ps(); +#endif + for (size_t ic8_tmp = 0; ic8_tmp < ic8; ++ic8_tmp) { +#ifndef ENABLE_DEBUG + asm volatile( + // 1 + "vmovups (%1), %%ymm8\n" + + "vbroadcastss (%0), %%ymm9\n" + "vbroadcastss 32(%0), %%ymm10\n" + "vbroadcastss 64(%0), %%ymm11\n" + "vbroadcastss 96(%0), %%ymm12\n" + "vbroadcastss 128(%0), %%ymm13\n" + "vbroadcastss 160(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n" + + "vbroadcastss 192(%0), %%ymm9\n" + "vbroadcastss 224(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n" + + // 2 + "vmovups 32(%1), %%ymm15\n" + + "vbroadcastss 4(%0), %%ymm11\n" + "vbroadcastss 36(%0), %%ymm12\n" + "vbroadcastss 68(%0), %%ymm13\n" + "vbroadcastss 100(%0), %%ymm14\n" + "vbroadcastss 132(%0), %%ymm9\n" + "vbroadcastss 164(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n" + + "vbroadcastss 196(%0), %%ymm11\n" + "vbroadcastss 228(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + + // 3 + "vmovups 64(%1), %%ymm8\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vbroadcastss 40(%0), %%ymm14\n" + "vbroadcastss 72(%0), %%ymm9\n" + "vbroadcastss 104(%0), %%ymm10\n" + "vbroadcastss 136(%0), %%ymm11\n" + "vbroadcastss 168(%0), %%ymm12\n" + + "vfmadd231ps %%ymm13, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm5\n" + + "vbroadcastss 200(%0), %%ymm13\n" + "vbroadcastss 232(%0), %%ymm14\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm7\n" + + // 4 + "vmovups 96(%1), %%ymm15\n" + + "vbroadcastss 12(%0), %%ymm9\n" + "vbroadcastss 44(%0), %%ymm10\n" + "vbroadcastss 76(%0), %%ymm11\n" + "vbroadcastss 108(%0), %%ymm12\n" + "vbroadcastss 140(%0), %%ymm13\n" + "vbroadcastss 172(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm15,
%%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n" + + "vbroadcastss 204(%0), %%ymm9\n" + "vbroadcastss 236(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm7\n" + + // 5 + "vmovups 128(%1), %%ymm8\n" + + "vbroadcastss 16(%0), %%ymm11\n" + "vbroadcastss 48(%0), %%ymm12\n" + "vbroadcastss 80(%0), %%ymm13\n" + "vbroadcastss 112(%0), %%ymm14\n" + "vbroadcastss 144(%0), %%ymm9\n" + "vbroadcastss 176(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm5\n" + + "vbroadcastss 208(%0), %%ymm11\n" + "vbroadcastss 240(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm7\n" + + // 6 + "vmovups 160(%1), %%ymm15\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vbroadcastss 52(%0), %%ymm14\n" + "vbroadcastss 84(%0), %%ymm9\n" + "vbroadcastss 116(%0), %%ymm10\n" + "vbroadcastss 148(%0), %%ymm11\n" + "vbroadcastss 180(%0), %%ymm12\n" + + "vfmadd231ps %%ymm13, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + + "vbroadcastss 212(%0), %%ymm13\n" + "vbroadcastss 244(%0), %%ymm14\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm7\n" + + // 7 + "vmovups 192(%1), %%ymm8\n" + + "vbroadcastss 24(%0), %%ymm9\n" + "vbroadcastss 56(%0), %%ymm10\n" + "vbroadcastss 88(%0), %%ymm11\n" + "vbroadcastss 120(%0), %%ymm12\n" + "vbroadcastss 152(%0), %%ymm13\n" + "vbroadcastss 184(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n" + + "vbroadcastss 216(%0), %%ymm9\n" + "vbroadcastss 248(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n" + + // 8 + "vmovups 224(%1), %%ymm15\n" + + "vbroadcastss 28(%0), %%ymm11\n" + "vbroadcastss 60(%0), %%ymm12\n" + "vbroadcastss 92(%0), %%ymm13\n" + "vbroadcastss 124(%0), %%ymm14\n" + "vbroadcastss 156(%0), %%ymm9\n" + "vbroadcastss 188(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n" + + "vbroadcastss 220(%0), %%ymm11\n" + "vbroadcastss 252(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + : + : "r"(src), "r"(weight) + : "memory"); +#else + for (int j = 0; j < C8NUM; ++j) { + __m256 weight_data = _mm256_loadu_ps(weight + j * C8NUM); + dst1 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j)), dst1); + dst2 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C8NUM)), dst2); + dst3 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C16NUM)), dst3); + dst4 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C24NUM)), dst4); + dst5 = 
_mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C32NUM)), dst5); + dst6 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C40NUM)), dst6); + dst7 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C48NUM)), dst7); + dst8 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C56NUM)), dst8); + } +#endif + src += C64NUM; + weight += C64NUM; + } +#ifndef ENABLE_DEBUG + asm volatile( + "vmovups %%ymm0, (%[dst])\n\t" + "vmovups %%ymm1, 32(%[dst])\n\t" + "vmovups %%ymm2, 64(%[dst])\n\t" + "vmovups %%ymm3, 96(%[dst])\n\t" + "vmovups %%ymm4, 128(%[dst])\n\t" + "vmovups %%ymm5, 160(%[dst])\n\t" + "vmovups %%ymm6, 192(%[dst])\n\t" + "vmovups %%ymm7, 224(%[dst])\n\t" + : + : [dst] "r"(dst) + : "memory", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7"); +#else + _mm256_storeu_ps(dst, dst1); + _mm256_storeu_ps(dst + C8NUM, dst2); + _mm256_storeu_ps(dst + C16NUM, dst3); + _mm256_storeu_ps(dst + C24NUM, dst4); + _mm256_storeu_ps(dst + C32NUM, dst5); + _mm256_storeu_ps(dst + C40NUM, dst6); + _mm256_storeu_ps(dst + C48NUM, dst7); + _mm256_storeu_ps(dst + C56NUM, dst8); +#endif + dst += cal_num; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c new file mode 100644 index 0000000000000000000000000000000000000000..c67513fc5c3e463090308aadeb2200003648d847 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c @@ -0,0 +1,357 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/avx/common_utils.h" + +void WinogradPostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + size_t stride = oc8div + oc8mod; + plane_stride /= sizeof(float); + int loop_c8 = 0; + size_t src_stride = plane_size * C8NUM + plane_stride; + for (; loop_c8 <= (int)(oc8div)-C32NUM; loop_c8 += C32NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + __m256 bias4 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias4 = _mm256_loadu_ps(bias + C24NUM); + bias += C32NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src13 = _mm256_loadu_ps(src + src_stride * C3NUM); + __m256 src14 = _mm256_loadu_ps(src + src_stride * C3NUM + C8NUM); + + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src13 = _mm256_add_ps(src13, bias4); + src14 = _mm256_add_ps(src14, bias4); + + ActBlock8Avx(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + _mm256_stream_ps(dst_c8 + C16NUM, src9); + _mm256_stream_ps(dst_c8 + C24NUM, src13); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + _mm256_stream_ps(dst_c8 + C16NUM, src10); + _mm256_stream_ps(dst_c8 + C24NUM, src14); + dst_c8 += stride; + + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + __m256 src15 = _mm256_loadu_ps(src + src_stride * C3NUM + C16NUM); + __m256 src16 = _mm256_loadu_ps(src + src_stride * C3NUM + C24NUM); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + src15 = _mm256_add_ps(src15, bias4); + src16 = _mm256_add_ps(src16, bias4); + + ActBlock8Avx(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + _mm256_stream_ps(dst_c8 + C16NUM, src11); + _mm256_stream_ps(dst_c8 + C24NUM, src15); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + _mm256_stream_ps(dst_c8 + C16NUM, src12); + _mm256_stream_ps(dst_c8 + C24NUM, src16); + dst_c8 += stride; + src += C32NUM; + } + for (; plane_size_tmp > 
0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src4 = _mm256_loadu_ps(src + src_stride * C3NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + src4 = _mm256_add_ps(src4, bias4); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + _mm256_stream_ps(dst_c8 + C16NUM, src3); + _mm256_stream_ps(dst_c8 + C24NUM, src4); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += C3NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C24NUM; loop_c8 += C24NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias += C24NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + + ActBlock12Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, + relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + _mm256_stream_ps(dst_c8 + C16NUM, src9); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + _mm256_stream_ps(dst_c8 + C16NUM, src10); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + _mm256_stream_ps(dst_c8 + C16NUM, src11); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + _mm256_stream_ps(dst_c8 + C16NUM, src12); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src3, relu_type == 1, relu_type == C3NUM); + + 
_mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + _mm256_stream_ps(dst_c8 + C16NUM, src3); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += C2NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C16NUM; loop_c8 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + + ActBlock8Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + + ActBlock2Avx(&src1, &src2, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += src_stride; + } + for (; loop_c8 < (int)(oc8div); loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + } + if (oc8mod == 0) { + return; + } + __m256 bias1 = 
_mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += C8NUM, dst_c1 += stride) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + __m128 src_high = _mm256_extractf128_ps(src1, 1); + + switch (oc8mod) { + case 1: + dst_c1[0] = _mm256_cvtss_f32(src1); + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + dst_c1[C2NUM] = MS_F32X8_GETI(src1, C2NUM); + break; + case C4NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + break; + case C5NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_store_ss(dst_c1 + C4NUM, src_high); + break; + case C6NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + break; + case C7NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + dst_c1[C6NUM] = MS_F32X4_GETI(src_high, C2NUM); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c new file mode 100644 index 0000000000000000000000000000000000000000..634d19f46aca80c92ce71804ed9d080b903467a8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c @@ -0,0 +1,355 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c8 = length * C8NUM; + size_t S_step = length * w * C8NUM; + for (int h1 = 0; h1 < h; ++h1) { + const float *SW = S; + memset(M, 0, len_c8 * w * sizeof(float)); + for (int w_tmp = w; w_tmp > 0; --w_tmp) { + const float *SK = SW; + const float *BK = B; + int k_tmp = k; + for (; k_tmp >= C8NUM; k_tmp -= C8NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + __m256 k5 = _mm256_set1_ps(*(BK + C4NUM * h)); + __m256 k6 = _mm256_set1_ps(*(BK + C5NUM * h)); + __m256 k7 = _mm256_set1_ps(*(BK + C6NUM * h)); + __m256 k8 = _mm256_set1_ps(*(BK + C7NUM * h)); + BK += C8NUM * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + M2 = _mm256_fmadd_ps(s4, k4, M2); + __m256 s5 = _mm256_loadu_ps(SK + C4NUM * S_step); + M1 = _mm256_fmadd_ps(s5, k5, M1); + __m256 s6 = _mm256_loadu_ps(SK + C5NUM * S_step); + M2 = _mm256_fmadd_ps(s6, k6, M2); + __m256 s7 = _mm256_loadu_ps(SK + C6NUM * S_step); + M1 = _mm256_fmadd_ps(s7, k7, M1); + __m256 s8 = _mm256_loadu_ps(SK + C7NUM * S_step); + M2 = _mm256_fmadd_ps(s8, k8, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C8NUM * S_step - len_c8; + } + for (; k_tmp >= C4NUM; k_tmp -= C4NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + BK += C4NUM * h; + int len_tmp = length; + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C16NUM, M += C16NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + __m256 s22 = _mm256_loadu_ps(SK + S_step + C8NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + __m256 s33 = _mm256_loadu_ps(SK + C2NUM * S_step + C8NUM); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + __m256 s44 = _mm256_loadu_ps(SK + C3NUM * S_step + C8NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M4 = _mm256_fmadd_ps(s44, k4, M4); + M1 = _mm256_add_ps(M1, M2); + M4 = _mm256_add_ps(M3, M4); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M4); + } + for (; len_tmp > 0; len_tmp--, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, 
k3, M1); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C4NUM * S_step - len_c8; + } + for (; k_tmp >= C3NUM; k_tmp -= C3NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + BK += C3NUM * h; + int len_tmp = length; + for (; len_tmp >= C3NUM; len_tmp -= C3NUM, SK += C24NUM, M += C24NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 M5 = _mm256_loadu_ps(M + C16NUM); + __m256 M6 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(SK + S_step + C8NUM); + __m256 s111 = _mm256_loadu_ps(SK + C16NUM); + __m256 s222 = _mm256_loadu_ps(SK + S_step + C16NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + M5 = _mm256_fmadd_ps(s111, k1, M5); + M6 = _mm256_fmadd_ps(s222, k2, M6); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + __m256 s33 = _mm256_loadu_ps(SK + C2NUM * S_step + C8NUM); + __m256 s333 = _mm256_loadu_ps(SK + C2NUM * S_step + C16NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M5 = _mm256_fmadd_ps(s333, k3, M5); + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + M5 = _mm256_add_ps(M5, M6); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + _mm256_storeu_ps(M + C16NUM, M5); + } + for (; len_tmp > 0; len_tmp--, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C3NUM * S_step - len_c8; + } + for (; k_tmp > 0; k_tmp -= 1) { + __m256 k1 = _mm256_set1_ps(*BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 s0 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s0, k1, M1); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += S_step - len_c8; + } + SW += len_c8; + M += len_c8; + } + B += 1; + } +} + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c8 = length * C8NUM, k_step = len_c8 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { + const float *BW = B; + memset(M, 0, len_c8 * w * sizeof(float)); + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c8) { + const float *SK = S, *BK = BW; + int k_tmp = k; + for (; k_tmp >= C8NUM; k_tmp -= C8NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + __m256 k5 = _mm256_set1_ps(*(BK + C4NUM * h)); + __m256 k6 = _mm256_set1_ps(*(BK + C5NUM * h)); + __m256 k7 = _mm256_set1_ps(*(BK + C6NUM * h)); + __m256 k8 = _mm256_set1_ps(*(BK + C7NUM * h)); + BK += C8NUM * h; + const float *S2 = SK + len_c8, *S3 = S2 + len_c8; + const float *S4 = S3 + len_c8, *S5 = S4 + len_c8; + const float *S6 = S5 + len_c8, *S7 = 
S6 + len_c8, *S8 = S7 + len_c8; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, + S4 += C8NUM, S5 += C8NUM, S6 += C8NUM, S7 += C8NUM, S8 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(S4); + M2 = _mm256_fmadd_ps(s4, k4, M2); + __m256 s5 = _mm256_loadu_ps(S5); + M1 = _mm256_fmadd_ps(s5, k5, M1); + __m256 s6 = _mm256_loadu_ps(S6); + M2 = _mm256_fmadd_ps(s6, k6, M2); + __m256 s7 = _mm256_loadu_ps(S7); + M1 = _mm256_fmadd_ps(s7, k7, M1); + __m256 s8 = _mm256_loadu_ps(S8); + M2 = _mm256_fmadd_ps(s8, k8, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S8; // step past all C8NUM k-rows consumed above (S7 would re-read the last row) + } + for (; k_tmp >= C4NUM; k_tmp -= C4NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + BK += C4NUM * h; + const float *S2 = SK + len_c8; + const float *S3 = S2 + len_c8; + const float *S4 = S3 + len_c8; + int len_tmp = length; + for (; len_tmp >= C2NUM; + len_tmp -= C2NUM, M += C16NUM, SK += C16NUM, S2 += C16NUM, S3 += C16NUM, S4 += C16NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(S2); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(S2 + C8NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + + __m256 s3 = _mm256_loadu_ps(S3); + __m256 s4 = _mm256_loadu_ps(S4); + __m256 s33 = _mm256_loadu_ps(S3 + C8NUM); + __m256 s44 = _mm256_loadu_ps(S4 + C8NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M4 = _mm256_fmadd_ps(s44, k4, M4); + + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + } + for (; len_tmp > 0; len_tmp--, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(S4); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S4; + } + for (; k_tmp >= C3NUM; k_tmp -= C3NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + BK += C3NUM * h; + const float *S2 = SK + len_c8; + const float *S3 = S2 + len_c8; + int len_tmp = length; + for (; len_tmp >= C3NUM; len_tmp -= C3NUM, M += C24NUM, SK += C24NUM, S2 += C24NUM, S3 += C24NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 M5 = _mm256_loadu_ps(M + C16NUM); + __m256 M6 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(S2); + __m256 s11 = _mm256_loadu_ps(SK +
C8NUM); + __m256 s22 = _mm256_loadu_ps(S2 + C8NUM); + __m256 s111 = _mm256_loadu_ps(SK + C16NUM); + __m256 s222 = _mm256_loadu_ps(S2 + C16NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + M5 = _mm256_fmadd_ps(s111, k1, M5); + M6 = _mm256_fmadd_ps(s222, k2, M6); + __m256 s3 = _mm256_loadu_ps(S3); + __m256 s33 = _mm256_loadu_ps(S3 + C8NUM); + __m256 s333 = _mm256_loadu_ps(S3 + C16NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M5 = _mm256_fmadd_ps(s333, k3, M5); + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + M5 = _mm256_add_ps(M6, M5); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + _mm256_storeu_ps(M + C16NUM, M5); + } + for (; len_tmp > 0; len_tmp--, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S3; + } + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + _mm256_storeu_ps(M, M1); + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..8335748ae8d31dce56c6414216b3d123f5cbcfa8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/intrinsics/avx/common_utils.h" +#include <x86intrin.h> + +__m128i _mm_adds_epi32(__m128i a, __m128i b) { + __m128i int_min = _mm_set1_epi32(0x80000000); + __m128i int_max = _mm_set1_epi32(0x7FFFFFFF); + + const __m128i res = _mm_add_epi32(a, b); + const __m128i sign_and = _mm_and_si128(a, b); + const __m128i sign_or = _mm_or_si128(a, b); + + const __m128i min_sat_mask = _mm_andnot_si128(res, sign_and); + const __m128i max_sat_mask = _mm_andnot_si128(sign_or, res); + const __m128 res_temp = + _mm_blendv_ps(_mm_castsi128_ps(res), _mm_castsi128_ps(int_min), _mm_castsi128_ps(min_sat_mask)); + return _mm_castps_si128(_mm_blendv_ps(res_temp, _mm_castsi128_ps(int_max), _mm_castsi128_ps(max_sat_mask))); +} + +__m128i _mm_rshr_epi32(__m128i a, int shift) { + const __m128i vmask = _mm_cmpgt_epi32(_mm_setzero_si128(), a); + const __m128i vabs_a = _mm_sub_epi32(_mm_xor_si128(a, vmask), vmask); + const __m128i tmp_res = _mm_srli_epi32(vabs_a, shift); + return _mm_xor_si128(tmp_res, vmask); +} + +__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b) { + const __m128i tmp_a_lo = _mm_unpacklo_epi32(a, _mm_setzero_si128()); + const __m128i tmp_a_hi = _mm_unpackhi_epi32(a, _mm_setzero_si128()); + const __m256i tmp_a_256 = _mm256_set_m128i(tmp_a_hi, tmp_a_lo); + const __m128i tmp_b_lo = _mm_unpacklo_epi32(b, _mm_setzero_si128()); + const __m128i tmp_b_hi = _mm_unpackhi_epi32(b, _mm_setzero_si128()); + const __m256i tmp_b_256 = _mm256_set_m128i(tmp_b_hi, tmp_b_lo); + __m256i tmp_out = _mm256_mul_epi32(tmp_a_256, tmp_b_256); + tmp_out = _mm256_add_epi64(tmp_out, _mm256_set1_epi64x(1ll << 30)); + const __m256i vmask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), tmp_out); + const __m256i vabs_tmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask); + tmp_out = _mm256_srli_epi64(vabs_tmp_out, 31); + const __m256i vtmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask); + const int32_t max_32bit = (1ll << 31) - 1; + const int32_t min_32bit = -(1ll << 31); + int64_t *tmp_out_ptr = (int64_t *)(&vtmp_out); + int32_t r1 = tmp_out_ptr[0] > max_32bit ? max_32bit : tmp_out_ptr[0]; + r1 = r1 < min_32bit ? min_32bit : r1; + int32_t r2 = tmp_out_ptr[1] > max_32bit ? max_32bit : tmp_out_ptr[1]; + r2 = r2 < min_32bit ? min_32bit : r2; + int32_t r3 = tmp_out_ptr[2] > max_32bit ? max_32bit : tmp_out_ptr[2]; + r3 = r3 < min_32bit ? min_32bit : r3; + int32_t r4 = tmp_out_ptr[3] > max_32bit ? max_32bit : tmp_out_ptr[3]; + r4 = r4 < min_32bit ? min_32bit : r4; + return _mm_set_epi32(r4, r3, r2, r1); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0b80a83a1440fb06a6cd1ea9ab9a97a669a138d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ +#define MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ + +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif + +#ifdef __cplusplus +extern "C" { +#endif +#ifdef __GNUC__ +#if __GNUC__ < 8 +#define _mm256_set_m128i(xmm1, xmm2) \ + _mm256_permute2f128_si256(_mm256_castsi128_si256(xmm1), _mm256_castsi128_si256(xmm2), 2) +#define _mm256_set_m128f(xmm1, xmm2) \ + _mm256_permute2f128_ps(_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2) +#endif +#endif + +#define AVX_ACT_RELU 1 +#define AVX_ACT_RELU6 3 + +// Signed saturating Add +__m128i _mm_adds_epi32(__m128i a, __m128i b); + +// Signed rounding shift right +__m128i _mm_rshr_epi32(__m128i a, int shift); + +// Signed saturating Rounding Doubling Multiply return High half +__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b); + +static inline void ActBlock1Avx(__m256 *v1, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + } +} + +static inline void ActBlock2Avx(__m256 *v1, __m256 *v2, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + } +} + +static inline void ActBlock4Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + *v3 = _mm256_max_ps(zero_ma, *v3); + *v4 = _mm256_max_ps(zero_ma, *v4); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + *v3 = _mm256_min_ps(relu6_ma, *v3); + *v4 = _mm256_min_ps(relu6_ma, *v4); + } +} + +static inline void ActBlock8Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, __m256 *v5, __m256 *v6, __m256 *v7, + __m256 *v8, size_t relu_type) { + __m256 relu6 = _mm256_set1_ps(6.0); + __m256 zero = _mm256_setzero_ps(); + switch (relu_type) { + case AVX_ACT_RELU6: + *v1 = _mm256_min_ps(*v1, relu6); + *v2 = _mm256_min_ps(*v2, relu6); + *v3 = _mm256_min_ps(*v3, relu6); + *v4 = _mm256_min_ps(*v4, relu6); + *v5 = _mm256_min_ps(*v5, relu6); + *v6 = _mm256_min_ps(*v6, relu6); + *v7 = _mm256_min_ps(*v7, relu6); + *v8 = _mm256_min_ps(*v8, relu6); + // fall through: relu6 also needs the zero lower bound applied below + case AVX_ACT_RELU: + *v1 = _mm256_max_ps(*v1, zero); + *v2 = _mm256_max_ps(*v2, zero); + *v3 = _mm256_max_ps(*v3, zero); + *v4 = _mm256_max_ps(*v4, zero); + *v5 = _mm256_max_ps(*v5, zero); + *v6 = _mm256_max_ps(*v6, zero); + *v7 = _mm256_max_ps(*v7, zero); + *v8 = _mm256_max_ps(*v8, zero); + break; + default: + break; + } +} + +static inline void ActBlock12Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, __m256 *v5, __m256 *v6, __m256 *v7, + __m256 *v8, __m256 *v9, __m256 *v10, __m256 *v11, __m256 *v12, size_t relu, + size_t relu6) { + if (relu || relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + *v3 = _mm256_max_ps(zero_ma, *v3); + *v4 = _mm256_max_ps(zero_ma, *v4); + *v5 = _mm256_max_ps(zero_ma, *v5); + *v6 = _mm256_max_ps(zero_ma, *v6); + *v7 = _mm256_max_ps(zero_ma, *v7); + *v8 = _mm256_max_ps(zero_ma, *v8); + *v9 = _mm256_max_ps(zero_ma, *v9); + *v10 =
_mm256_max_ps(zero_ma, *v10); + *v11 = _mm256_max_ps(zero_ma, *v11); + *v12 = _mm256_max_ps(zero_ma, *v12); + } + if (relu6) { + __m256 relu6_ma = _mm256_set1_ps(6.0f); + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + *v3 = _mm256_min_ps(relu6_ma, *v3); + *v4 = _mm256_min_ps(relu6_ma, *v4); + *v5 = _mm256_min_ps(relu6_ma, *v5); + *v6 = _mm256_min_ps(relu6_ma, *v6); + *v7 = _mm256_min_ps(relu6_ma, *v7); + *v8 = _mm256_min_ps(relu6_ma, *v8); + *v9 = _mm256_min_ps(relu6_ma, *v9); + *v10 = _mm256_min_ps(relu6_ma, *v10); + *v11 = _mm256_min_ps(relu6_ma, *v11); + *v12 = _mm256_min_ps(relu6_ma, *v12); + } +} + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..5918725b66bc561cf11a3bbc29347c4bf0ee678e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h @@ -0,0 +1,446 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include <math.h> +#include <float.h> + +#ifdef _MSC_VER +#include <immintrin.h> +#define MS_F32X16_GETI(src, i) src.m512_f32[i] +#define MS512_F32_GETI(src, i) src.m512_f32[i] +#else +#include <x86intrin.h> +#define MS_F32X16_GETI(src, i) src[i] +#define MS512_F32_GETI(src, i) src[i] +#endif + +#pragma GCC push_options +#pragma GCC target("avx512f") + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X16 __m512 +#define MS_FLOAT512_F32 __m512 +#define MS_INT32X16 __m512i +#define MS_INT512_EPI32 __m512i +#define MS_MASK512_TYPE __mmask16 +#define MS_LD512_F32 _mm512_loadu_ps +#define MS_LD512_EPI32(src) _mm512_loadu_si512((__m512i const *)(src)) +#define MS_LD512_HALF_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD512_F32 _mm512_add_ps +#define MS_ADD512_EPI32 _mm512_add_epi32 +#define MS_MOV512_F32 _mm512_set1_ps +#define MS_MOV512_EPI32 _mm512_set1_epi32 +#define MS_MOV512_VAL0_F32 _mm512_setzero_ps() +#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1) +#define MS_ST512_F32 _mm512_storeu_ps +#define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2) +#define MS_ST512_HALF_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB512_F32 _mm512_sub_ps +#define MS_SUB512_EPI32 _mm512_sub_epi32 +#define MS_MAX512_F32 _mm512_max_ps +#define MS_MAX512_EPI32 _mm512_max_epi32 +#define MS_MIN512_F32 _mm512_min_ps +#define MS_MIN512_EPI32 _mm512_min_epi32 +#define MS_SQRT512_F32 _mm512_sqrt_ps +#define MS_RSQRT512_F32 _mm512_rsqrt14_ps +#define MS_SIN512_F32 _mm512_sin_ps +#define MS_ERF512_F32 _mm512_erf_ps +#define MS_ABS512_F32 _mm512_abs_ps +#define MS_ABS512_EPI32 _mm512_abs_epi32 +
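For reference, the _mm_qrdmulh_epi32 helper added in common_utils.c above emulates ARM NEON's SQRDMULH (signed saturating rounding doubling multiply returning the high half) with SSE/AVX2 intrinsics: it widens each lane to 64 bits, adds the rounding bias, shifts the absolute value and restores the sign, then clamps to the int32 range. A minimal scalar sketch of the per-lane computation follows; the name qrdmulh_s32 and the <stdint.h> dependency are illustrative, not part of this patch.

#include <stdint.h>

static int32_t qrdmulh_s32(int32_t a, int32_t b) {
  int64_t p = (int64_t)a * (int64_t)b + (1ll << 30);  // widen, then add the rounding bias
  // the vector code logically shifts the absolute value and restores the sign,
  // i.e. it divides by 2^31 rounding toward zero
  int64_t r = (p < 0) ? -((-p) >> 31) : (p >> 31);
  if (r > INT32_MAX) r = INT32_MAX;  // saturate, like the r1..r4 clamps above
  if (r < INT32_MIN) r = INT32_MIN;
  return (int32_t)r;
}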
+#define MS_ROUND512_F32(src) \ + _mm512_add_round_ps(src, _mm512_set1_ps(0.0f), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR512_F32 _mm512_floor_ps +#define MS_CEIL512_F32 _mm512_ceil_ps +#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2) +#define MS_MUL512_EPI32(src1, src2) _mm512_mullo_epi32(src1, src2) +#define MS_FMADD512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3) +#define MS_FMSUB512_F32(src1, src2, src3) _mm512_fmsub_ps(src1, src2, src3) +#define MS_FSMUL512_F32(src1, src2, src3) _mm512_fnmadd_ps(src3, src2, src1) // src1 - src2 * src3 +#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2) +#define MS_MUL512_N_F32(src1, src2) _mm512_mul_ps(src1, _mm512_set1_ps(src2)) +#define MS_MUL512_N_EPI32(src1, src2) _mm512_mullo_epi32(src1, _mm512_set1_epi32(src2)) +#define MS_DIV512_N_F32(src1, src2) _mm512_div_ps(src1, _mm512_set1_ps(src2)) +#define MS_SLLI512_EPI32(src1, src2) _mm512_slli_epi32(src1, src2) +#define MS_CVT512PS_EPI32(src) _mm512_cvttps_epi32(src) // truncate float to int +#define MS_CVT512EPI32_PS(src) _mm512_cvtepi32_ps(src) // convert int to float +#define MS_CMP512_F32(src1, src2, src3) _mm512_cmp_ps_mask(src1, src2, src3) +#define MS_CMPGT512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 30) +#define MS_CMPLE512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 18) +#define MS_CMPLT512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 17) +#define MS_CMPGT512_EPI32(src1, src2) _mm512_cmpgt_epi32(src1, src2) +#define MS_BLEND512_F32(src1, src2, mask) _mm512_mask_blend_ps(mask, src1, src2) +#define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2) +#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src) +#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src) +#define MS_GET_MAX512_F32(src) _mm512_reduce_max_ps(src) +#define MS_GET_MIN512_F32(src) _mm512_reduce_min_ps(src) +#define MS_GET_SUM512_F32(src) _mm512_reduce_add_ps(src) +#define MS_AND512_MASK(src1, src2) _mm512_kand(src1, src2) + +#define MS512_SRLI_EPI32(src1, src2) _mm512_srli_epi32(src1, src2) +#define MS512_AND_EPI32(src1, src2) _mm512_and_si512(src1, src2) +#define MS512_CASTPS_EPI32(src) _mm512_castps_si512(src) +#define MS_OR512_EPI32(src1, src2) _mm512_or_epi32(src1, src2) +#define MS_AND512_EPI32(src1, src2) _mm512_and_epi32(src1, src2) +#define MS_AND512_F32(src1, src2) \ + _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(src1), _mm512_castps_si512(src2))) + +static inline MS_FLOAT512_F32 SIMD_SIGN512_F32(MS_FLOAT512_F32 src) { + MS_FLOAT512_F32 abs_src = MS_ABS512_F32(src); + MS_FLOAT512_F32 sign = MS_DIV512_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS512_F32(src, abs_src) MS_DIV512_F32(abs_src, src) + +static inline MS_FLOAT512_F32 MS_OR512_F32(MS_FLOAT512_F32 src1, MS_FLOAT512_F32 src2) { + MS_FLOAT512_F32 result = MS_CAST512_F32_S32(MS_OR512_EPI32(MS512_CASTPS_EPI32(src1), MS512_CASTPS_EPI32(src2))); + return result; +} + +static inline MS_FLOAT512_F32 MS512_ANDNOT_F32(MS_FLOAT512_F32 src1, MS_FLOAT512_F32 src2) { + MS_FLOAT512_F32 result = MS_CAST512_F32_S32(MS_AND512_EPI32(~MS512_CASTPS_EPI32(src1), MS512_CASTPS_EPI32(src2))); + return result; +} + +static inline MS_FLOAT512_F32 MS_AND512_MASK_F32(MS_MASK512_TYPE mask, MS_FLOAT512_F32 value) { + /* mask = T ? 
value : 0 */ + MS_FLOAT512_F32 zeros = _mm512_set1_ps(0.0f); + return _mm512_mask_blend_ps(mask, zeros, value); +} + +static inline MS_FLOAT32X16 MS_POW512_F32(MS_FLOAT32X16 src1, MS_FLOAT32X16 src2) { + MS_FLOAT32X16 dst; + MS512_F32_GETI(dst, 0) = powf(MS512_F32_GETI(src1, 0), MS512_F32_GETI(src2, 0)); + MS512_F32_GETI(dst, 1) = powf(MS512_F32_GETI(src1, 1), MS512_F32_GETI(src2, 1)); + MS512_F32_GETI(dst, 2) = powf(MS512_F32_GETI(src1, 2), MS512_F32_GETI(src2, 2)); + MS512_F32_GETI(dst, 3) = powf(MS512_F32_GETI(src1, 3), MS512_F32_GETI(src2, 3)); + MS512_F32_GETI(dst, 4) = powf(MS512_F32_GETI(src1, 4), MS512_F32_GETI(src2, 4)); + MS512_F32_GETI(dst, 5) = powf(MS512_F32_GETI(src1, 5), MS512_F32_GETI(src2, 5)); + MS512_F32_GETI(dst, 6) = powf(MS512_F32_GETI(src1, 6), MS512_F32_GETI(src2, 6)); + MS512_F32_GETI(dst, 7) = powf(MS512_F32_GETI(src1, 7), MS512_F32_GETI(src2, 7)); + MS512_F32_GETI(dst, 8) = powf(MS512_F32_GETI(src1, 8), MS512_F32_GETI(src2, 8)); + MS512_F32_GETI(dst, 9) = powf(MS512_F32_GETI(src1, 9), MS512_F32_GETI(src2, 9)); + MS512_F32_GETI(dst, 10) = powf(MS512_F32_GETI(src1, 10), MS512_F32_GETI(src2, 10)); + MS512_F32_GETI(dst, 11) = powf(MS512_F32_GETI(src1, 11), MS512_F32_GETI(src2, 11)); + MS512_F32_GETI(dst, 12) = powf(MS512_F32_GETI(src1, 12), MS512_F32_GETI(src2, 12)); + MS512_F32_GETI(dst, 13) = powf(MS512_F32_GETI(src1, 13), MS512_F32_GETI(src2, 13)); + MS512_F32_GETI(dst, 14) = powf(MS512_F32_GETI(src1, 14), MS512_F32_GETI(src2, 14)); + MS512_F32_GETI(dst, 15) = powf(MS512_F32_GETI(src1, 15), MS512_F32_GETI(src2, 15)); + return dst; +} + +static inline MS_FLOAT32X16 MS_COS512_F32(MS_FLOAT32X16 src) { + static const MS_FLOAT32X16 pi = {PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI}; + static const MS_FLOAT32X16 pi2_neg = {-2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, + -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI}; + static const MS_FLOAT32X16 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT512_F32 src_abs = MS_ABS512_F32(src); + MS_FLOAT512_F32 src_cycle = + MS_ADD512_F32(MS_MUL512_F32(MS_FLOOR512_F32(MS_MUL512_F32(MS_ADD512_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + static const MS_FLOAT512_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT512_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT512_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT512_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT512_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 
-1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + MS_FLOAT32X16 square = MS_MUL512_F32(src_cycle, src_cycle); + + MS_FLOAT32X16 tmp = + MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_MUL512_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X16 tmp1 = MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(tmp, neg), square), data2); + MS_FLOAT512_F32 res = MS_ADD512_F32( + MS_MUL512_F32( + MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X16 MS512_LOG_F32(MS_FLOAT32X16 src) { + const MS_INT512_EPI32 gFloatExpMask = MS_MOV512_EPI32(0xffULL << 23); + const MS_INT512_EPI32 gFloatExp0 = MS_MOV512_EPI32(127ULL << 23); + const MS_INT512_EPI32 gExpNormalizer = MS_MOV512_EPI32(127); + static const MS_FLOAT512_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT512_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, + 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT512_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, + 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT512_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, + 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT512_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT512_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT512_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X16 ln2 = {LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2}; + + const MS_INT512_EPI32 exps32 = MS512_SRLI_EPI32(MS512_AND_EPI32(gFloatExpMask, MS512_CASTPS_EPI32(src)), 23); + const MS_INT512_EPI32 normExps = MS_SUB512_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X16 expsPD = MS_CVT512EPI32_PS(normExps); + const MS_FLOAT32X16 y = + MS_OR512_F32(MS_CAST512_F32_S32(gFloatExp0), MS512_ANDNOT_F32(MS_CAST512_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X16 div = MS_DIV512_F32(MS_ADD512_F32(y, neg), MS_ADD512_F32(y, pos)); + MS_FLOAT32X16 square = MS_MUL512_F32(div, div); + + MS_FLOAT32X16 tmp = MS_ADD512_F32( + MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(square, MS_ADD512_F32(MS_MUL512_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X16 tmp1 = MS_MUL512_F32(square, MS_ADD512_F32(MS_MUL512_F32(square, tmp), data4)); + MS_FLOAT32X16 res = + MS_ADD512_F32(MS_MUL512_F32(ln2, expsPD), MS_MUL512_F32(MS_MUL512_F32(div, MS_ADD512_F32(tmp1, data5)), 
data6)); + MS_MASK512_TYPE mask = MS_CMP512_F32(src, MS_MOV512_F32(0.0f), _CMP_EQ_OQ); + res = MS_BLEND512_F32(res, MS_MOV512_F32(-INFINITY), mask); + mask = MS_CMP512_F32(src, MS_MOV512_F32(INFINITY), _CMP_EQ_OQ); + res = MS_BLEND512_F32(res, MS_MOV512_F32(INFINITY), mask); + mask = MS_CMPLT512_F32(src, MS_MOV512_F32(0.0f)); + res = MS_BLEND512_F32(res, MS_MOV512_F32(NAN), mask); + mask = MS_CMP512_F32(src, MS_MOV512_F32(0.0f), _CMP_UNORD_Q); + res = MS_BLEND512_F32(res, MS_MOV512_F32(NAN), mask); + return res; +} + +#define MS_DIV512_EPI32(src1, src2) \ + _mm512_cvttps_epi32(MS_DIV512_F32(_mm512_cvtepi32_ps(src1), _mm512_cvtepi32_ps(src2))) + +#define MS512_INT16_TO_FLOAT16(src) _mm512_cvtepi16_ph(src) +#define MS512_FLOAT16_TO_INT16(src) _mm512_cvttph_epi16(src) + +#define MS512_INT32_TO_FLOAT16(src) _mm512_cvtepi32_ph(src) +#define MS512_FLOAT16_TO_INT32(src) _mm512_cvttph_epi32(src) + +#define MS512_INT32_TO_FLOAT32(src) _mm512_cvtepi32_ps(src) +#define MS512_FLOAT32_TO_INT32(src) _mm512_cvttps_epi32(src) +#define MS512_FLOAT16_TO_FLOAT32(src) _mm512_cvtph_ps(src) +#define MS512_FLOAT32_TO_FLOAT16(src1, src2) _mm512_cvtps_ph(src1, src2) + +#define MS512_INT64_TO_FLOAT32(src) _mm512_cvtepi64_ps(src) +#define MS512_FLOAT32_TO_INT64(src) _mm512_cvttps_epi64(src) + +#define MS512_INT64_TO_FLOAT16(src) _mm512_cvtepi64_ph(src) +#define MS512_FLOAT16_TO_INT64(src) _mm512_cvttph_epi64(src) + +#define MS512_INT32_TO_FLOAT64(src) _mm512_cvtepi32_pd(src) +#define MS512_FLOAT64_TO_INT32(src) _mm512_cvttpd_epi32(src) + +#define MS512_INT64_TO_FLOAT64(src) _mm512_cvtepi64_pd(src) +#define MS512_FLOAT64_TO_INT64(src) _mm512_cvttpd_epi64(src) + +#define MS512_INT16_TO_INT32(src) _mm512_cvtepi16_epi32(src) +#define MS512_INT16_TO_INT64(src) _mm512_cvtepi16_epi64(src) +#define MS512_INT32_TO_INT16(src) _mm512_cvtepi32_epi16(src) +#define MS512_INT32_TO_INT64(src) _mm512_cvtepi32_epi64(src) +#define MS512_INT64_TO_INT16(src) _mm512_cvtepi64_epi16(src) +#define MS512_INT64_TO_INT32(src) _mm512_cvtepi64_epi32(src) + +static inline MS_FLOAT32X16 simd_exp512_f32(MS_FLOAT32X16 input) { + static MS_FLOAT32X16 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X16 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + static MS_FLOAT32X16 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, + 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, + 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, + 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, + 1.0f / 6, 1.0f / 6, 
1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}}; + + input = MS_MAX512_F32(minv, MS_MIN512_F32(input, maxv)); + MS_INT32X16 integer = MS_CVT512PS_EPI32(MS_FLOOR512_F32(MS_FMADD512_F32(input, param[6], param[4]))); + MS_FLOAT32X16 decimal = MS_SUB512_F32(input, MS_MUL512_F32(MS_CVT512EPI32_PS(integer), param[0])); + MS_INT32X16 int_exp = MS_SLLI512_EPI32(MS_ADD512_EPI32(integer, MS_MOV512_EPI32(126)), 23); + MS_FLOAT32X16 tmp = MS_FMADD512_F32(decimal, MS_FMADD512_F32(decimal, param[1], param[2]), param[3]); + tmp = MS_FMADD512_F32(decimal, MS_FMADD512_F32(decimal, tmp, param[4]), param[5]); + MS_FLOAT32X16 decimal_exp = MS_FMADD512_F32(decimal, tmp, param[5]); + return MS_MUL512_F32(param[7], MS_MUL512_F32(decimal_exp, MS_CAST512_F32_S32(int_exp))); +} + +static inline MS_FLOAT32X16 simd_hexp512_f32(MS_FLOAT32X16 src) { + MS_FLOAT32X16 dst; + MS512_F32_GETI(dst, 0) = exp(MS512_F32_GETI(src, 0)); + MS512_F32_GETI(dst, 1) = exp(MS512_F32_GETI(src, 1)); + MS512_F32_GETI(dst, 2) = exp(MS512_F32_GETI(src, 2)); + MS512_F32_GETI(dst, 3) = exp(MS512_F32_GETI(src, 3)); + MS512_F32_GETI(dst, 4) = exp(MS512_F32_GETI(src, 4)); + MS512_F32_GETI(dst, 5) = exp(MS512_F32_GETI(src, 5)); + MS512_F32_GETI(dst, 6) = exp(MS512_F32_GETI(src, 6)); + MS512_F32_GETI(dst, 7) = exp(MS512_F32_GETI(src, 7)); + MS512_F32_GETI(dst, 8) = exp(MS512_F32_GETI(src, 8)); + MS512_F32_GETI(dst, 9) = exp(MS512_F32_GETI(src, 9)); + MS512_F32_GETI(dst, 10) = exp(MS512_F32_GETI(src, 10)); + MS512_F32_GETI(dst, 11) = exp(MS512_F32_GETI(src, 11)); + MS512_F32_GETI(dst, 12) = exp(MS512_F32_GETI(src, 12)); + MS512_F32_GETI(dst, 13) = exp(MS512_F32_GETI(src, 13)); + MS512_F32_GETI(dst, 14) = exp(MS512_F32_GETI(src, 14)); + MS512_F32_GETI(dst, 15) = exp(MS512_F32_GETI(src, 15)); + return dst; +} + +static inline void simd_exp512(MS_FLOAT32X16 input, float *dst) { + MS_FLOAT32X16 res = simd_exp512_f32(input); + MS_ST512_F32(dst, res); +} + +static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) { + static const MS_FLOAT32X16 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, + 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X16 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, + 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X16 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X16 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, + 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X16 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, + 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 
3150.0f}; + static const MS_FLOAT32X16 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, + 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X16 square = MS_MUL512_F32(src, src); + MS_FLOAT32X16 a = + MS_MUL512_F32(MS_FMADD512_F32(MS_FMADD512_F32(MS_ADD512_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X16 b = + MS_FMADD512_F32(MS_FMADD512_F32(MS_FMADD512_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X16 res = MS_DIV512_F32(a, b); + MS_FLOAT32X16 up_limit = MS_MOV512_F32(5.0f); + MS_FLOAT32X16 down_limit = MS_MOV512_F32(-5.0f); + MS_MASK512_TYPE up_mask = MS_CMPGT512_F32(src, up_limit); + MS_MASK512_TYPE down_mask = MS_CMPLT512_F32(src, down_limit); + res = MS_BLEND512_F32(res, pos, up_mask); + res = MS_BLEND512_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH512_F32 MS_TANHX16_F32 + +static inline MS_FLOAT32X16 MS512_ERF_F32(MS_FLOAT32X16 src) { + MS_FLOAT32X16 dst; + MS_F32X16_GETI(dst, 0) = erff(MS_F32X16_GETI(src, 0)); + MS_F32X16_GETI(dst, 1) = erff(MS_F32X16_GETI(src, 1)); + MS_F32X16_GETI(dst, 2) = erff(MS_F32X16_GETI(src, 2)); + MS_F32X16_GETI(dst, 3) = erff(MS_F32X16_GETI(src, 3)); + MS_F32X16_GETI(dst, 4) = erff(MS_F32X16_GETI(src, 4)); + MS_F32X16_GETI(dst, 5) = erff(MS_F32X16_GETI(src, 5)); + MS_F32X16_GETI(dst, 6) = erff(MS_F32X16_GETI(src, 6)); + MS_F32X16_GETI(dst, 7) = erff(MS_F32X16_GETI(src, 7)); + MS_F32X16_GETI(dst, 8) = erff(MS_F32X16_GETI(src, 8)); + MS_F32X16_GETI(dst, 9) = erff(MS_F32X16_GETI(src, 9)); + MS_F32X16_GETI(dst, 10) = erff(MS_F32X16_GETI(src, 10)); + MS_F32X16_GETI(dst, 11) = erff(MS_F32X16_GETI(src, 11)); + MS_F32X16_GETI(dst, 12) = erff(MS_F32X16_GETI(src, 12)); + MS_F32X16_GETI(dst, 13) = erff(MS_F32X16_GETI(src, 13)); + MS_F32X16_GETI(dst, 14) = erff(MS_F32X16_GETI(src, 14)); + MS_F32X16_GETI(dst, 15) = erff(MS_F32X16_GETI(src, 15)); + return dst; +} + +#define MS_LOAD512X8_F32(src, input_ptr, num) \ + MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \ + MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \ + MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \ + MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); \ + MS_FLOAT32X16 src##5 = MS_LD512_F32(input_ptr + 4 * num); \ + MS_FLOAT32X16 src##6 = MS_LD512_F32(input_ptr + 5 * num); \ + MS_FLOAT32X16 src##7 = MS_LD512_F32(input_ptr + 6 * num); \ + MS_FLOAT32X16 src##8 = MS_LD512_F32(input_ptr + 7 * num); + +#define MS_LOAD512X4_F32(src, input_ptr, num) \ + MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \ + MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \ + MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \ + MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); + +#define MS_FMADD512X8_F32(src, weight, dst) \ + dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA512_F32(dst##4, src##4, weight); \ + dst##5 = MS_MLA512_F32(dst##5, src##5, weight); \ + dst##6 = MS_MLA512_F32(dst##6, src##6, weight); \ + dst##7 = MS_MLA512_F32(dst##7, src##7, weight); \ + dst##8 = MS_MLA512_F32(dst##8, src##8, weight); + +#define 
MS_FMADD512X4_F32(src, weight, dst) \ + dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA512_F32(dst##4, src##4, weight); + +#define MS_SET_ZERO512X8_F32(dst) \ + MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##5 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##6 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##7 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##8 = _mm512_setzero_ps(); + +#define MS_SET_ZERO512X4_F32(dst) \ + MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); + +#pragma GCC pop_options + +#endif // NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..2b3647d84423417b94a61351471a8f50639567a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h @@ -0,0 +1,440 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include <math.h> + +#ifdef _MSC_VER +#include <immintrin.h> +#define MS_F32X8_GETI(src, i) src.m256_f32[i] +#define MS256_F32_GETI(src, i) src.m256_f32[i] +#else +#include <x86intrin.h> +#define MS_F32X8_GETI(src, i) src[i] +#define MS256_F32_GETI(src, i) src[i] +#endif + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X8 __m256 +#define MS_FLOAT256_F32 __m256 +#define MS_INT32X8 __m256i +#define MS_INT256_EPI32 __m256i +#define MS_MASK256_TYPE MS_FLOAT32X8 +#define MS_LD256_F32 _mm256_loadu_ps +#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD256_F32 _mm256_add_ps +#define MS_ADD256_EPI32 _mm256_add_epi32 +#define MS_MOV256_F32 _mm256_set1_ps +#define MS_MOV256_EPI32 _mm256_set1_epi32 +#define MS_MOV256_VAL0_F32 _mm256_setzero_ps() +#define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1) +#define MS_ST256_F32 _mm256_storeu_ps +#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB256_F32 _mm256_sub_ps +#define MS_SUB256_EPI32 _mm256_sub_epi32 +#define MS_MAX256_F32 _mm256_max_ps +#define MS_MAX256_EPI32 _mm256_max_epi32 +#define MS_MIN256_F32 _mm256_min_ps +#define MS_MIN256_EPI32 _mm256_min_epi32 +#define MS_SQRT256_F32 _mm256_sqrt_ps +#define MS_RSQRT256_F32 _mm256_rsqrt_ps +#define MS_SIN256_F32 _mm256_sin_ps +#define MS_ERF256_F32 _mm256_erf_ps +#define MS_ROUND256_F32(src) _mm256_round_ps(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR256_F32 _mm256_floor_ps +#define MS_CEIL256_F32 _mm256_ceil_ps +#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2) +#define MS_MUL256_EPI32(src1, src2) _mm256_mullo_epi32(src1, src2) +#define MS_FMADD256_F32(src1, src2, src3) _mm256_fmadd_ps(src1, src2, src3) +#define MS_FMSUB256_F32(src1, src2, src3) _mm256_fmsub_ps(src1, src2, src3) +#define MS_FSMUL256_F32(src1, src2, src3) _mm256_fnmadd_ps(src3, src2, src1) // src1 - src2 * src3 +#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2) +#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2)) +#define MS_MUL256_N_EPI32(src1, src2) _mm256_mullo_epi32(src1, _mm256_set1_epi32(src2)) +#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2)) +#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2) +#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src) // truncate float to int +#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) // convert int to float +#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) +#define MS_CMPGT256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 30) +#define MS_CMPLE256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 18) +#define MS_CMPLT256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 17) +#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2) +#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) +#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3) +#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src) +#define MS_AND256_MASK(src1, src2) _mm256_and_ps(src1, src2) +#define MS_OR256_F32(src1, src2) _mm256_or_ps(src1, src2) +#define MS_AND256_MASK_F32(src1, src2) _mm256_and_ps(src1, src2) +#define MS_AND256_F32(src1, src2) _mm256_and_ps(src1, src2) + +#define MS256_ANDNOT_F32(src1, src2) _mm256_andnot_ps(src1, src2) +#define MS256_SRLI_EPI32(src1, src2) _mm256_srli_epi32(src1, src2) +#define MS256_AND_EPI32(src1, src2) _mm256_and_si256(src1, src2)
+#define MS256_CASTPS_EPI32(src) _mm256_castps_si256(src) + +static inline MS_FLOAT32X8 MS_POW256_F32(MS_FLOAT32X8 src1, MS_FLOAT32X8 src2) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = powf(MS_F32X8_GETI(src1, 0), MS_F32X8_GETI(src2, 0)); + MS_F32X8_GETI(dst, 1) = powf(MS_F32X8_GETI(src1, 1), MS_F32X8_GETI(src2, 1)); + MS_F32X8_GETI(dst, 2) = powf(MS_F32X8_GETI(src1, 2), MS_F32X8_GETI(src2, 2)); + MS_F32X8_GETI(dst, 3) = powf(MS_F32X8_GETI(src1, 3), MS_F32X8_GETI(src2, 3)); + MS_F32X8_GETI(dst, 4) = powf(MS_F32X8_GETI(src1, 4), MS_F32X8_GETI(src2, 4)); + MS_F32X8_GETI(dst, 5) = powf(MS_F32X8_GETI(src1, 5), MS_F32X8_GETI(src2, 5)); + MS_F32X8_GETI(dst, 6) = powf(MS_F32X8_GETI(src1, 6), MS_F32X8_GETI(src2, 6)); + MS_F32X8_GETI(dst, 7) = powf(MS_F32X8_GETI(src1, 7), MS_F32X8_GETI(src2, 7)); + return dst; +} + +static inline MS_FLOAT32X8 MS_ABS256_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = fabsf(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = fabsf(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = fabsf(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = fabsf(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = fabsf(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = fabsf(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = fabsf(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = fabsf(MS_F32X8_GETI(src, 7)); + return dst; +} + +static inline MS_FLOAT256_F32 SIMD_SIGN256_F32(MS_FLOAT256_F32 src) { + MS_FLOAT256_F32 abs_src = MS_ABS256_F32(src); + MS_FLOAT256_F32 sign = MS_DIV256_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS256_F32(src, abs_src) MS_DIV256_F32(abs_src, src) + +static inline MS_FLOAT32X8 MS_COS256_F32(MS_FLOAT32X8 src) { + static const MS_FLOAT32X8 pi = {PI, PI, PI, PI, PI, PI, PI, PI}; + static const MS_FLOAT32X8 pi2_neg = { + -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, + }; + static const MS_FLOAT32X8 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT256_F32 src_abs = MS_ABS256_F32(src); + MS_FLOAT256_F32 src_cycle = + MS_ADD256_F32(MS_MUL256_F32(MS_FLOOR256_F32(MS_MUL256_F32(MS_ADD256_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + + static const MS_FLOAT256_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT256_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT256_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT256_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT256_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + MS_FLOAT32X8 square = MS_MUL256_F32(src_cycle, src_cycle); + + MS_FLOAT32X8 tmp = + MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_MUL256_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X8 tmp1 = MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(tmp, neg), square), data2); + MS_FLOAT256_F32 res = MS_ADD256_F32( + MS_MUL256_F32( + MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + 
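MS_COS256_F32 above (like its 16-lane twin MS_COS512_F32) first folds the argument into a single 2*pi cycle using |x| and a floor-based reduction, then evaluates a factored degree-10 Taylor polynomial for cosine. A scalar sketch of the same computation follows; cos_approx is an illustrative name, not part of this patch.

#include <math.h>

static float cos_approx(float x) {
  const float pi = 3.1415926f;                       // same PI constant the header uses
  float a = fabsf(x);                                // cos(-x) == cos(x)
  a -= 2.0f * pi * floorf((a + pi) / (2.0f * pi));   // reduce to one cycle in [-pi, pi)
  float s = a * a;
  // factored Taylor series: 1 - s/2*(1 - s/12*(1 - s/30*(1 - s/56*(1 - s/90))))
  return 1.0f - s / 2.0f * (1.0f - s / 12.0f * (1.0f - s / 30.0f * (1.0f - s / 56.0f * (1.0f - s / 90.0f))));
}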
+static inline MS_FLOAT32X8 MS256_LOG_F32(MS_FLOAT32X8 src) { + const MS_INT256_EPI32 gFloatExpMask = MS_MOV256_EPI32(0xffULL << 23); + const MS_INT256_EPI32 gFloatExp0 = MS_MOV256_EPI32(127ULL << 23); + const MS_INT256_EPI32 gExpNormalizer = MS_MOV256_EPI32(127); + static const MS_FLOAT256_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT256_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT256_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT256_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT256_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT256_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT256_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X8 ln2 = {LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2}; + + const MS_INT256_EPI32 exps32 = MS256_SRLI_EPI32(MS256_AND_EPI32(gFloatExpMask, MS256_CASTPS_EPI32(src)), 23); + const MS_INT256_EPI32 normExps = MS_SUB256_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X8 expsPD = MS_CVT256EPI32_PS(normExps); + const MS_FLOAT32X8 y = + MS_OR256_F32(MS_CAST256_F32_S32(gFloatExp0), MS256_ANDNOT_F32(MS_CAST256_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X8 div = MS_DIV256_F32(MS_ADD256_F32(y, neg), MS_ADD256_F32(y, pos)); + MS_FLOAT32X8 square = MS_MUL256_F32(div, div); + + MS_FLOAT32X8 tmp = MS_ADD256_F32( + MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(square, MS_ADD256_F32(MS_MUL256_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X8 tmp1 = MS_MUL256_F32(square, MS_ADD256_F32(MS_MUL256_F32(square, tmp), data4)); + MS_FLOAT32X8 res = + MS_ADD256_F32(MS_MUL256_F32(ln2, expsPD), MS_MUL256_F32(MS_MUL256_F32(div, MS_ADD256_F32(tmp1, data5)), data6)); + MS_FLOAT32X8 mask = MS_CMP256_F32(src, MS_MOV256_F32(0.0f), _CMP_EQ_OQ); + res = MS_BLEND256_F32(res, MS_MOV256_F32(-INFINITY), mask); + mask = MS_CMP256_F32(src, MS_MOV256_F32(INFINITY), _CMP_EQ_OQ); + res = MS_BLEND256_F32(res, MS_MOV256_F32(INFINITY), mask); + mask = MS_OR256_F32(MS_CMPLT256_F32(src, MS_MOV256_F32(0.0f)), MS_CMP256_F32(src, MS_MOV256_F32(0.0f), _CMP_UNORD_Q)); + res = MS_BLEND256_F32(res, MS_MOV256_F32(NAN), mask); + return res; +} + +static inline float MS_GET_MAX256_F32(__m256 src) { + float result = MS_F32X8_GETI(src, 0); + for (int i = 1; i < 8; i++) { // avx block num : 8 + result = fmaxf(result, MS_F32X8_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM256_F32(__m256 src) { + float result = MS_F32X8_GETI(src, 0); + for (int i = 1; i < 8; i++) { // avx block num : 8 + result = result + MS_F32X8_GETI(src, i); + } + return result; +} + +#define MS_DIV256_EPI32(src1, src2) \ + _mm256_cvttps_epi32(MS_DIV256_F32(_mm256_cvtepi32_ps(src1), _mm256_cvtepi32_ps(src2))) + +#define MS256_INT16_TO_FLOAT16(src) _mm256_cvtepi16_ph(src) +#define MS256_FLOAT16_TO_INT16(src) _mm256_cvttph_epi16(src) + +#define MS256_INT32_TO_FLOAT16(src) _mm256_cvtepi32_ph(src) +#define MS256_FLOAT16_TO_INT32(src) _mm256_cvttph_epi32(src) + +#define MS256_INT32_TO_FLOAT32(src) _mm256_cvtepi32_ps(src) +#define 
MS256_FLOAT32_TO_INT32(src) _mm256_cvttps_epi32(src) + +#define MS256_INT64_TO_FLOAT32(src) _mm256_cvtepi64_ps(src) +#define MS256_FLOAT32_TO_INT64(src) _mm256_cvttps_epi64(src) + +#define MS256_INT64_TO_FLOAT16(src) _mm256_cvtepi64_ph(src) +#define MS256_FLOAT16_TO_INT64(src) _mm256_cvttph_epi64(src) + +#define MS256_INT32_TO_FLOAT64(src) _mm256_cvtepi32_pd(src) +#define MS256_FLOAT64_TO_INT32(src) _mm256_cvttpd_epi32(src) + +#define MS256_INT64_TO_FLOAT64(src) _mm256_cvtepi64_pd(src) +#define MS256_FLOAT64_TO_INT64(src) _mm256_cvttpd_epi64(src) + +#define MS256_INT16_TO_INT32(src) _mm256_cvtepi16_epi32(src) +#define MS256_INT16_TO_INT64(src) _mm256_cvtepi16_epi64(src) +#define MS256_INT32_TO_INT16(src) _mm256_cvtepi32_epi16(src) +#define MS256_INT32_TO_INT64(src) _mm256_cvtepi32_epi64(src) +#define MS256_INT64_TO_INT16(src) _mm256_cvtepi64_epi16(src) +#define MS256_INT64_TO_INT32(src) _mm256_cvtepi64_epi32(src) + +static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = sqrtf(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = sqrtf(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = sqrtf(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = sqrtf(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = sqrtf(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = sqrtf(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = sqrtf(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = sqrtf(MS_F32X8_GETI(src, 7)); + return dst; +} + +#define MS_LOAD256X4_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); + +#define MS_LOAD256X8_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); + +#define MS_LOAD256X16_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); \ + MS_FLOAT32X8 src##9 = MS_LD256_F32(input_ptr + 8 * num); \ + MS_FLOAT32X8 src##10 = MS_LD256_F32(input_ptr + 9 * num); \ + MS_FLOAT32X8 src##11 = MS_LD256_F32(input_ptr + 10 * num); \ + MS_FLOAT32X8 src##12 = MS_LD256_F32(input_ptr + 11 * num); \ + MS_FLOAT32X8 src##13 = MS_LD256_F32(input_ptr + 12 * num); \ + MS_FLOAT32X8 src##14 = MS_LD256_F32(input_ptr + 13 * num); \ + MS_FLOAT32X8 src##15 = MS_LD256_F32(input_ptr + 14 * num); \ + MS_FLOAT32X8 src##16 = MS_LD256_F32(input_ptr + 15 * num); + +#define STORE256X8_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + 
MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); + +#define STORE256X16_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); \ + MS_ST256_F32(output_ptr + 8 * num, dst##9); \ + MS_ST256_F32(output_ptr + 9 * num, dst##10); \ + MS_ST256_F32(output_ptr + 10 * num, dst##11); \ + MS_ST256_F32(output_ptr + 11 * num, dst##12); \ + MS_ST256_F32(output_ptr + 12 * num, dst##13); \ + MS_ST256_F32(output_ptr + 13 * num, dst##14); \ + MS_ST256_F32(output_ptr + 14 * num, dst##15); \ + MS_ST256_F32(output_ptr + 15 * num, dst##16); + +static inline MS_FLOAT32X8 simd_exp256_f32(MS_FLOAT32X8 input) { + static MS_FLOAT32X8 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X8 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + static MS_FLOAT32X8 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}}; + input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv)); + MS_INT32X8 integer = MS_CVT256PS_EPI32(MS_FLOOR256_F32(MS_FMADD256_F32(input, param[6], param[4]))); + MS_FLOAT32X8 decimal = MS_SUB256_F32(input, MS_MUL256_F32(MS_CVT256EPI32_PS(integer), param[0])); + MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(126)), 23); + MS_FLOAT32X8 tmp = MS_FMADD256_F32(decimal, MS_FMADD256_F32(decimal, param[1], param[2]), param[3]); + tmp = MS_FMADD256_F32(decimal, MS_FMADD256_F32(decimal, tmp, param[4]), param[5]); + MS_FLOAT32X8 decimal_exp = MS_FMADD256_F32(decimal, tmp, param[5]); + return MS_MUL256_F32(param[7], MS_MUL256_F32(decimal_exp, MS_CAST256_F32_S32(int_exp))); +} + +static inline MS_FLOAT32X8 simd_hexp256_f32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = exp(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = exp(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = exp(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = exp(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = exp(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = exp(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = exp(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = exp(MS_F32X8_GETI(src, 7)); + return dst; +} + +static inline void simd_exp256(MS_FLOAT32X8 
input, float *dst) { + MS_FLOAT32X8 res = simd_exp256_f32(input); + MS_ST256_F32(dst, res); +} + +static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) { + static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X8 square = MS_MUL256_F32(src, src); + MS_FLOAT32X8 a = + MS_MUL256_F32(MS_FMADD256_F32(MS_FMADD256_F32(MS_ADD256_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X8 b = + MS_FMADD256_F32(MS_FMADD256_F32(MS_FMADD256_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X8 res = MS_DIV256_F32(a, b); + MS_FLOAT32X8 up_limit = MS_MOV256_F32(5.0f); + MS_FLOAT32X8 down_limit = MS_MOV256_F32(-5.0f); + MS_FLOAT32X8 up_mask = MS_CMPGT256_F32(src, up_limit); + MS_FLOAT32X8 down_mask = MS_CMPLT256_F32(src, down_limit); + res = MS_BLEND256_F32(res, pos, up_mask); + res = MS_BLEND256_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH256_F32 MS_TANHX8_F32 + +static inline MS_FLOAT32X8 MS256_ERF_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = erff(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = erff(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = erff(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = erff(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = erff(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = erff(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = erff(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = erff(MS_F32X8_GETI(src, 7)); + return dst; +} + +#define MS_FMADD256X8_F32(src, weight, dst) \ + dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA256_F32(dst##4, src##4, weight); \ + dst##5 = MS_MLA256_F32(dst##5, src##5, weight); \ + dst##6 = MS_MLA256_F32(dst##6, src##6, weight); \ + dst##7 = MS_MLA256_F32(dst##7, src##7, weight); \ + dst##8 = MS_MLA256_F32(dst##8, src##8, weight); + +#define MS_SET_ZERO256X8_F32(dst) \ + MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##5 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##6 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##7 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##8 = _mm256_setzero_ps(); + +#define MS_FMADD256X4_F32(src, weight, dst) \ + dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA256_F32(dst##4, src##4, weight); + +#define MS_SET_ZERO256X4_F32(dst) \ + MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##3 = 
_mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); + +#define MS_REDUCE_ADD256_F32(src) (src = _mm256_hadd_ps(src, src), src = _mm256_hadd_ps(src, src), src[0] + src[4]); +#endif // NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c new file mode 100644 index 0000000000000000000000000000000000000000..8348f795aeec3d3b290c06ff2b38b3eca31d8405 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c @@ -0,0 +1,141 @@ + + +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#include <stdlib.h> +#include <string.h> +#include <stdbool.h> +#include "nnacl_c/errorcode.h" + +typedef unsigned int DWORD; +struct X86CpuInfoContext { + bool fma_flag_; + bool sse4_1_flag_; + bool avx2_flag_; + bool avx512_flag_; +}; + +static struct X86CpuInfoContext g_x86_cpu_info_context_; + +inline const bool X86_Fma_Support(void) { return g_x86_cpu_info_context_.fma_flag_; } + +inline const bool X86_Sse_Support(void) { +#ifdef ENABLE_SSE + return g_x86_cpu_info_context_.sse4_1_flag_; +#else + return false; +#endif +} + +inline const bool X86_Avx_Support(void) { +#ifdef ENABLE_AVX + return g_x86_cpu_info_context_.avx2_flag_; +#else + return false; +#endif +} + +inline const bool X86_Avx512_Support(void) { +#ifdef ENABLE_AVX512 + return g_x86_cpu_info_context_.avx512_flag_; +#else + return false; +#endif +} + +void ExecuteCpuIdCmd(DWORD cmd_code, DWORD *eax_data, DWORD *ebx_data, DWORD *ecx_data, DWORD *edx_data) { + DWORD deax, debx, decx, dedx; + asm volatile( + "movl %4, %%eax;\n" + "movl $0, %%ecx;\n" + "cpuid;\n" + "movl %%eax, %0;\n" + "movl %%ebx, %1;\n" + "movl %%ecx, %2;\n" + "movl %%edx, %3;\n" + : "=r"(deax), "=r"(debx), "=r"(decx), "=r"(dedx) + : "r"(cmd_code) + : "%eax", "%ebx", "%ecx", "%edx"); + + *eax_data = deax; + *ebx_data = debx; + *ecx_data = decx; + *edx_data = dedx; +} + +bool IsIntelX86Platform(void) { + DWORD eax_data, ebx_data, ecx_data, edx_data; + + const int vid_info_size = 13; + char *vid_info = malloc(sizeof(char) * vid_info_size); + if (vid_info == NULL) { + return false; + } + memset(vid_info, 0, vid_info_size); + + ExecuteCpuIdCmd(0, &eax_data, &ebx_data, &ecx_data, &edx_data); // eax = 0, execute cpuid to get vid info + + memcpy(vid_info, &ebx_data, 4); // Copy the first 4 characters to the array[0:3] + memcpy(vid_info + 4, &edx_data, 4); // Copy the middle 4 characters to the array[4:8] + memcpy(vid_info + 8, &ecx_data, 4); // Copy the last 4 characters to the array[8:12] + + int x86_intel_flag = (strcmp(vid_info, "GenuineIntel") == 0 || strcmp(vid_info, "AuthenticAMD") == 0) ? 
1 : 0; + + free(vid_info); + return x86_intel_flag; +} + +int IntelX86CpuInfoInit(void) { + if (!IsIntelX86Platform()) { + return NNACL_ERR; + } + DWORD eax_data, ebx_data, ecx_data, edx_data; + ExecuteCpuIdCmd(1, &eax_data, &ebx_data, &ecx_data, &edx_data); // eax = 1, execute cpuid to get sse/fma flag + g_x86_cpu_info_context_.sse4_1_flag_ = (ecx_data & (1 << 19)) == 0 ? false : true; // sse4.1 flag is ecx bit 19 + g_x86_cpu_info_context_.fma_flag_ = (ecx_data & (1 << 12)) == 0 ? false : true; // fma flag is ecx bit 12 + + ExecuteCpuIdCmd(7, &eax_data, &ebx_data, &ecx_data, &edx_data); // eax = 7, execute cpuid to get avx2/avx512 flag + g_x86_cpu_info_context_.avx2_flag_ = (ebx_data & (1 << 5)) == 0 ? false : true; // avx2 flag is ebx bit 5 + g_x86_cpu_info_context_.avx512_flag_ = (ebx_data & (1 << 16)) == 0 ? false : true; // avx512f flag is ebx bit 16 + + return NNACL_OK; +} + +X86CpuInfoErrorCodeEnum IntelX86InstructionSetSupportCheck(void) { + if (IntelX86CpuInfoInit() != NNACL_OK) { + return X86CPUINFO_PLATFORM_ERR; + } +#if defined(ENABLE_AVX512) && !defined(AVX512_HARDWARE_SELF_AWARENESS) + if (!X86_Avx512_Support()) { + return X86CPUINFO_AVX512_ERR; + } +#endif + +#ifdef ENABLE_AVX + if (!X86_Avx_Support()) { + return X86CPUINFO_AVX_ERR; + } +#endif + +#ifdef ENABLE_SSE + if (!X86_Sse_Support()) { + return X86CPUINFO_SSE_ERR; + } +#endif + return X86CPUINFO_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h new file mode 100644 index 0000000000000000000000000000000000000000..cec5ef130e3fd4391343bd8049493ae4df4400e7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_MS_SIMD_CPU_INFO_H_ +#define NNACL_MS_SIMD_CPU_INFO_H_ + +#include <stdbool.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_AVX512 +#define AVX512_HARDWARE_SELF_AWARENESS +#endif + +#if defined(AVX512_HARDWARE_SELF_AWARENESS) +#define AVX512_HARDWARE_SELF_AWARENESS_BEGIN if (X86_Avx512_Support()) { +#define AVX512_HARDWARE_SELF_AWARENESS_END } +#else +#define AVX512_HARDWARE_SELF_AWARENESS_BEGIN +#define AVX512_HARDWARE_SELF_AWARENESS_END +#endif + +typedef enum X86CpuInfoErrorCodeEnum { + X86CPUINFO_OK = 0, + X86CPUINFO_PLATFORM_ERR = 1, + X86CPUINFO_AVX512_ERR, + X86CPUINFO_AVX_ERR, + X86CPUINFO_SSE_ERR, + X86CPUINFO_END = 9999 +} X86CpuInfoErrorCodeEnum; + +const bool X86_Fma_Support(void); +const bool X86_Sse_Support(void); +const bool X86_Avx_Support(void); +const bool X86_Avx512_Support(void); + +bool IsIntelX86Platform(void); +X86CpuInfoErrorCodeEnum IntelX86InstructionSetSupportCheck(void); + +int IntelX86CpuInfoInit(void); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_MS_SIMD_CPU_INFO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..030c610a5cb16c020f12826ed2f7ca114f4e2f53 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h @@ -0,0 +1,563 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include <math.h> +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" + +#ifdef ENABLE_AVX512 +#include "nnacl_c/intrinsics/ms_simd_avx512_instructions.h" +#endif + +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" +#endif + +#ifdef ENABLE_SSE +#include "nnacl_c/intrinsics/ms_simd_sse_instructions.h" +#endif + +#ifdef ENABLE_ARM +#include "nnacl_c/intrinsics/ms_simd_neon_instructions.h" +#endif + +#define MS_SIMD_AVX512_INSTRUCTION(instruction, suffix) instruction##512##suffix +#define MS_SIMD_AVX_INSTRUCTION(instruction, suffix) instruction##256##suffix +#define MS_SIMD_SSE_INSTRUCTION(instruction, suffix) instruction##128##suffix +#define MS_SIMD_NEON_INSTRUCTION(instruction, suffix) instruction##128##suffix + +#define MS_SIMD_INSTRUCTION_F32(instruction) MS_SIMD_INSTRUCTION(instruction, _F32) +#define MS_SIMD_INSTRUCTION_EPI32(instruction) MS_SIMD_INSTRUCTION(instruction, _EPI32) +#define MS_SIMD_INSTRUCTION_MASK(instruction) MS_SIMD_INSTRUCTION(instruction, _MASK) + +// define (float/int) data +#define SIMD_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOAT) +#define SIMD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_INT) +#define SIMD_MASK MS_SIMD_INSTRUCTION(MS_MASK, _TYPE) + +// read scalar data +#define SIMD_F32_GETI MS_SIMD_INSTRUCTION(MS, _F32_GETI) + +// move (float/int) data +#define SIMD_MOV_F32 MS_SIMD_INSTRUCTION_F32(MS_MOV) +#define SIMD_MOV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MOV) +#define SIMD_SET0_F32 MS_SIMD_INSTRUCTION(MS_MOV, _VAL0_F32) + +// load (float/int) data +#define SIMD_LD_F32 MS_SIMD_INSTRUCTION_F32(MS_LD) +#define SIMD_LD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_LD) +#define SIMD_LD_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_LD, _HALF_EPI32) + +// load 4 (float/int) data +#define SIMD_LDX4_F32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_F32) +#define SIMD_LDX4_EPI32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_EPI32) + +// stored (float/int) data +#define SIMD_ST_F32 MS_SIMD_INSTRUCTION_F32(MS_ST) +#define SIMD_ST_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ST) +#define SIMD_ST_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_ST, _HALF_EPI32) + +// sign +#define SIMD_SIGN_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGN) +#define SIMD_SIGNABS_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGNABS) + +// add (float/int) op +#define SIMD_ADD_F32 MS_SIMD_INSTRUCTION_F32(MS_ADD) +#define SIMD_ADD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ADD) +#define SIMD_ADD_N_F32(val1, val2) MS_EXPAND(SIMD_ADD_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_ADD_N_EPI32(val1, val2) MS_EXPAND(SIMD_ADD_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// sub (float/int) op +#define SIMD_SUB_F32 MS_SIMD_INSTRUCTION_F32(MS_SUB) +#define SIMD_SUB_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_SUB) +#define SIMD_SUB_N_F32(val1, val2) MS_EXPAND(SIMD_SUB_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_SUB_N_EPI32(val1, val2) MS_EXPAND(SIMD_SUB_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// div (float/int) op +#define SIMD_DIV_F32 MS_SIMD_INSTRUCTION_F32(MS_DIV) +#define SIMD_DIV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_DIV) +#define SIMD_DIV_N_F32(val1, val2) MS_EXPAND(SIMD_DIV_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_DIV_N_EPI32(val1, val2) MS_EXPAND(SIMD_DIV_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// sqrt (float) op +#define SIMD_SQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_SQRT) + +// rsqrt (float) op +#define SIMD_RSQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_RSQRT) + +// log (float) op +#define SIMD_LOG_F32 MS_SIMD_INSTRUCTION(MS, _LOG_F32) + +// cos (float) op +#define SIMD_COS_F32 MS_SIMD_INSTRUCTION_F32(MS_COS) + +// sin (float) op 
+#define SIMD_SIN_F32 MS_SIMD_INSTRUCTION_F32(MS_SIN)
+
+// erf (float) op
+#define SIMD_ERF_F32 MS_SIMD_INSTRUCTION(MS, _ERF_F32)
+
+// abs (float) op
+#define SIMD_ABS_F32 MS_SIMD_INSTRUCTION_F32(MS_ABS)
+#define SIMD_ABS_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ABS)
+
+// round (float) op
+#define SIMD_ROUND_F32 MS_SIMD_INSTRUCTION_F32(MS_ROUND)
+
+// ceil (float) op
+#define SIMD_CEIL_F32 MS_SIMD_INSTRUCTION_F32(MS_CEIL)
+
+// floor (float) op
+#define SIMD_FLOOR_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOOR)
+
+// tanh (float) op
+#define SIMD_TANH_F32 MS_SIMD_INSTRUCTION_F32(MS_TANH)
+
+// min (float/int) op
+#define SIMD_MIN_F32 MS_SIMD_INSTRUCTION_F32(MS_MIN)
+#define SIMD_MIN_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MIN)
+#define SIMD_MIN_N_F32(val1, val2) MS_EXPAND(SIMD_MIN_F32(val1, SIMD_MOV_F32(val2)))
+#define SIMD_MIN_N_EPI32(val1, val2) MS_EXPAND(SIMD_MIN_EPI32(val1, SIMD_MOV_EPI32(val2)))
+
+// max (float/int) op
+#define SIMD_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_MAX)
+#define SIMD_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MAX)
+#define SIMD_MAX_N_F32(val1, val2) MS_EXPAND(SIMD_MAX_F32(val1, SIMD_MOV_F32(val2)))
+#define SIMD_MAX_N_EPI32(val1, val2) MS_EXPAND(SIMD_MAX_EPI32(val1, SIMD_MOV_EPI32(val2)))
+
+// get max (float/int) op
+#define SIMD_GET_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_MAX)
+#define SIMD_GET_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_GET_MAX)
+
+// get sum (float) op
+#define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM)
+#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32)
+
+// clamp (float/int) op
+#define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val)
+#define SIMD_CLAMP_EPI32(val, min_val, max_val) SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, min_val), max_val)
+#define SIMD_CLAMP_N_F32(val, min_val, max_val) \
+  SIMD_MIN_F32(SIMD_MAX_F32(val, SIMD_MOV_F32(min_val)), SIMD_MOV_F32(max_val))
+#define SIMD_CLAMP_N_EPI32(val, min_val, max_val) \
+  SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, SIMD_MOV_EPI32(min_val)), SIMD_MOV_EPI32(max_val))
+
+// mul (float/int) op
+#define SIMD_MUL_F32 MS_SIMD_INSTRUCTION_F32(MS_MUL)
+#define SIMD_MUL_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MUL)
+#define SIMD_MUL_N_F32(val1, val2) MS_EXPAND(SIMD_MUL_F32(val1, SIMD_MOV_F32(val2)))
+#define SIMD_MUL_N_EPI32(val1, val2) MS_EXPAND(SIMD_MUL_EPI32(val1, SIMD_MOV_EPI32(val2)))
+
+// pow (float) op
+#define SIMD_POW_F32 MS_SIMD_INSTRUCTION_F32(MS_POW)
+
+// fma (float/int) op : src1 * src2 + src3
+#define SIMD_FMADD_F32 MS_SIMD_INSTRUCTION_F32(MS_FMADD)
+
+// fms (float/int) op : src1 * src2 - src3
+#define SIMD_FMSUB_F32 MS_SIMD_INSTRUCTION_F32(MS_FMSUB)
+
+// fsm (float) op : src1 - src2 * src3
+#define MS_FSMUL_F32 MS_SIMD_INSTRUCTION_F32(MS_FSMUL)
+
+// square (float/int) op
+#define SIMD_MUL_SQUARE_F32(val1) SIMD_MUL_F32(val1, val1)
+#define SIMD_MUL_SQUARE_EPI32(val1) SIMD_MUL_EPI32(val1, val1)
+
+// exp (float) op
+#define SIMD_EXP_ST_F32 MS_SIMD_INSTRUCTION(simd_exp, )
+#define SIMD_EXP_F32 MS_SIMD_INSTRUCTION(simd_exp, _f32)
+// exp (float) op, higher precision but a little slower
+#define SIMD_HEXP_F32 MS_SIMD_INSTRUCTION(simd_hexp, _f32)
+
+// cmp (float/int) op
+#define SIMD_CMPLT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLT)
+#define SIMD_CMPLE_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLE)
+#define SIMD_CMPGT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPGT)
+#define SIMD_BLEND_F32 MS_SIMD_INSTRUCTION_F32(MS_BLEND)
+
+// cast data
+#define MS_CAST_F32_S32 MS_SIMD_INSTRUCTION(MS_CAST, _F32_S32)
+
+// logical op
+#define SIMD_AND_MASK MS_SIMD_INSTRUCTION_MASK(MS_AND)
+#define SIMD_OR_F32 MS_SIMD_INSTRUCTION_F32(MS_OR)
+#define SIMD_AND_MASK_F32 MS_SIMD_INSTRUCTION(MS_AND, _MASK_F32)
+#define SIMD_AND_F32 MS_SIMD_INSTRUCTION_F32(MS_AND)
+
+#define SIMD_GETSIGN_F32(src) \
+  SIMD_OR_F32(SIMD_AND_F32(src, MS_CAST_F32_S32(SIMD_MOV_EPI32(0x80000000))), \
+              MS_CAST_F32_S32(SIMD_MOV_EPI32(0x3F800000)))
+
+// int32/float mutual conversion
+#define SIMD_EPI32_TO_F32 MS_SIMD_INSTRUCTION(MS, _INT32_TO_FLOAT32)
+#define SIMD_F32_TO_EPI32 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_INT32)
+#define SIMD_F16_TO_F32 MS_SIMD_INSTRUCTION(MS, _FLOAT16_TO_FLOAT32)
+#define SIMD_F32_TO_F16 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_FLOAT16)
+
+// enable avx512
+#if defined(ENABLE_AVX512)
+#define SIMD_RUN_AVX512(function, index, ...) \
+  do { \
+    AVX512_HARDWARE_SELF_AWARENESS_BEGIN \
+    index = function##AVX512(index, __VA_ARGS__); \
+    AVX512_HARDWARE_SELF_AWARENESS_END \
+  } while (0)
+#else
+#define SIMD_RUN_AVX512(function, index, ...)
+#endif
+
+// enable avx256
+#if defined(ENABLE_AVX)
+#define SIMD_RUN_AVX(function, index, ...) index = function##AVX(index, __VA_ARGS__)
+#else
+#define SIMD_RUN_AVX(function, index, ...)
+#endif
+
+// enable sse
+#if defined(ENABLE_SSE)
+#define SIMD_RUN_SSE(function, index, ...) index = function##SSE(index, __VA_ARGS__)
+#else
+#define SIMD_RUN_SSE(function, index, ...)
+#endif
+
+// enable neon
+#if defined(ENABLE_NEON)
+#define SIMD_RUN_NEON(function, index, ...) index = function##NEON(index, __VA_ARGS__)
+#else
+#define SIMD_RUN_NEON(function, index, ...)
+#endif
+
+#define SIMD_RUN_NO_SCALAR(function, index, ...) \
+  do { \
+    SIMD_RUN_AVX512(function, index, __VA_ARGS__); \
+    SIMD_RUN_AVX(function, index, __VA_ARGS__); \
+    SIMD_RUN_SSE(function, index, __VA_ARGS__); \
+    SIMD_RUN_NEON(function, index, __VA_ARGS__); \
+  } while (0)
+
+#define SIMD_RUN_X86_NO_SCALAR(function, index, ...) \
+  do { \
+    SIMD_RUN_AVX512(function, index, __VA_ARGS__); \
+    SIMD_RUN_AVX(function, index, __VA_ARGS__); \
+    SIMD_RUN_SSE(function, index, __VA_ARGS__); \
+  } while (0)
+
+#define SIMD512_BLOCK16 32  // SIMD : 512 = 16 x 32
+#define SIMD256_BLOCK16 16  // SIMD : 256 = 16 x 16
+#define SIMD128_BLOCK16 8   // SIMD : 128 = 16 x 8
+
+#define SIMD512_BLOCK32 16  // SIMD : 512 = 32 x 16
+#define SIMD256_BLOCK32 8   // SIMD : 256 = 32 x 8
+#define SIMD128_BLOCK32 4   // SIMD : 128 = 32 x 4
+
+#define SIMD512_BLOCK64 8  // SIMD : 512 = 64 x 8
+#define SIMD256_BLOCK64 4  // SIMD : 256 = 64 x 4
+#define SIMD128_BLOCK64 2  // SIMD : 128 = 64 x 2
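+
+// Each SIMDnnn_BLOCKkk above is the number of kk-bit lanes in an nnn-bit
+// register, i.e. how far a kernel advances its loop index per vector step.
+//
+// Illustrative dispatch sketch (assumes an op compiled into per-ISA variants
+// named ElementAddAVX512/ElementAddAVX/ElementAddSSE/ElementAddNEON, each
+// returning the first index it did not process):
+//   int i = 0;
+//   SIMD_RUN_NO_SCALAR(ElementAdd, i, in0, in1, out, size);
+//   for (; i < size; i++) out[i] = in0[i] + in1[i];  // scalar tail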
+
+#define MS_EXPAND(...) __VA_ARGS__
+
+// Scalar
+#define MS_FLOAT32X1 float
+#define MS_INT32X1 int
+#define MS_MOV32_F32(value) (value)
+#define MS_MOV32_EPI32(value) (value)
+#define MS_LD32_F32(address) (*(address))
+#define MS_LD32_EPI32(address) (*(address))
+#define MS_ST32_F32(address, value) (*(address) = (value))
+#define MS_ST32_EPI32(address, value) (*(address) = (value))
+#define MS_ADD32_F32(value1, value2) ((value1) + (value2))
+#define MS_ADD32_EPI32(value1, value2) ((value1) + (value2))
+#define MS_SUB32_F32(value1, value2) ((value1) - (value2))
+#define MS_SUB32_EPI32(value1, value2) ((value1) - (value2))
+#define MS_MUL32_F32(value1, value2) ((value1) * (value2))
+#define MS_MUL32_EPI32(value1, value2) ((value1) * (value2))
+#define MS_DIV32_F32(value1, value2) ((value1) / (value2))
+#define MS_DIV32_EPI32(value1, value2) ((value1) / (value2))
+#define MS_MIN32_F32(value1, value2) (fmin((value1), (value2)))
+#define MS_MIN32_EPI32(value1, value2) ((value1) < (value2) ? (value1) : (value2))
+#define MS_MAX32_F32(value1, value2) (fmax((value1), (value2)))
+#define MS_MAX32_EPI32(value1, value2) ((value1) > (value2) ? (value1) : (value2))
+#define MS_SQRT32_F32(value) (sqrt(value))
+
+static inline float simd_exp32_f32(float data) {
+  typedef union {
+    float f;
+    int i;
+  } fi;
+  static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};  // ln(2), 1/5!, 1/4!, 1/3!, 1/2!, 1/1!
+#ifdef _WIN32
+  if (data < -88.0f) {
+    return 0.0f;
+  } else if (data > 88.0f) {
+    return 1.6516363e+38;  // e^88 = 1.6516363e+38
+  }
+#else
+  data =
+    MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, data));  // clamp(logf(FLT_MIN), logf(FLT_MAX))
+#endif
+  int integer = floor(data * 1.44269504088896341f + 0.5f);  // n = round(x / ln(2)); 1.4426950... = log2(e)
+  float decimal = data - integer * param[0];
+  fi int_exp;
+  const int shift = 23;
+  const int bias = 126;
+  const float factor = 2;
+  // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r),
+  // because n may be 128, and it is not representable by fp32.
+  int_exp.i = (integer + bias) << shift;  // integer num 2^(n - 1) approximate calculation : ((n - 1) + 127) << 23
+  // Taylor expansion of exp(r) around 0
+  const float decimal_exp =
+    1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
+  return factor * int_exp.f * decimal_exp;
+}
+
+// exp(x) = exp(n * ln(2) + r) = 2^n * exp(r) = 2 * 2^(n - 1) * exp(r)
+static inline void simd_exp32(float src, float *dst) {
+  typedef union {
+    float f;
+    int i;
+  } fi;
+  static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};  // ln(2), 1/5!, 1/4!, 1/3!, 1/2!, 1/1!
+  src = MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, src));  // clamp(logf(FLT_MIN), logf(FLT_MAX))
+  int integer = floor(src * 1.44269504088896341f + 0.5f);
+  float decimal = src - integer * param[0];
+  fi int_exp;
+  const int shift = 23;
+  const int bias = 126;
+  const float factor = 2;
+  // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r),
+  // because n may be 128, and it is not representable by fp32.
+  int_exp.i = (integer + bias) << shift;  // integer num 2^(n - 1) approximate calculation : ((n - 1) + 127) << 23
+  const float decimal_exp =
+    1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
+  *dst = factor * int_exp.f * decimal_exp;
+}
+
+// define (float/int) data
+#define MS_FLOAT_32xN(byte_num) MS_FLOAT32##X##byte_num
+#define MS_INT_32xN(byte_num) MS_INT32##X##byte_num
+
+// move (float/int) data
+#define MS_MOVN_F32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_F32(__VA_ARGS__))
+#define MS_MOVN_EPI32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_EPI32(__VA_ARGS__))
+
+// load (float/int) data
+#define MS_LD_F32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_F32(__VA_ARGS__))
+#define MS_LD_EPI32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_EPI32(__VA_ARGS__))
+
+// load 4 (float/int) data
+#define MS_LDX4_F32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_F32(__VA_ARGS__))
+#define MS_LDX4_EPI32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_EPI32(__VA_ARGS__))
+
+// store (float/int) data
+#define MS_ST_F32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_F32(__VA_ARGS__))
+#define MS_ST_EPI32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_EPI32(__VA_ARGS__))
+
+// add (float/int) op
+#define MS_ADD_F32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_F32(__VA_ARGS__))
+#define MS_ADD_EPI32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_EPI32(__VA_ARGS__))
+#define MS_ADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
+#define MS_ADD_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))
+
+// sub (float/int) op
+#define MS_SUB_F32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_F32(__VA_ARGS__))
+#define MS_SUB_EPI32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_EPI32(__VA_ARGS__))
+#define MS_SUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
+#define MS_SUB_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))
+
+// div (float/int) op
+#define MS_DIV_F32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_F32(__VA_ARGS__))
+#define MS_DIV_EPI32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_EPI32(__VA_ARGS__))
+#define MS_DIV_N_F32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
+#define MS_DIV_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))
+
+// sqrt (float) op
+#define MS_SQRT_F32(bit_num, ...) MS_EXPAND(MS_SQRT##bit_num##_F32(__VA_ARGS__))
+
+// rsqrt (float) op
+#define MS_RSQRT_F32(bit_num, ...) MS_EXPAND(MS_RSQRT##bit_num##_F32(__VA_ARGS__))
+
+// log (float) op
+#define MS_LOG_F32(bit_num, ...) MS_EXPAND(MS_LOG##bit_num##_F32(__VA_ARGS__))
+
+// cos (float) op
+#define MS_COS_F32(bit_num, ...) MS_EXPAND(MS_COS##bit_num##_F32(__VA_ARGS__))
+
+// sin (float) op
+#define MS_SIN_F32(bit_num, ...) MS_EXPAND(MS_SIN##bit_num##_F32(__VA_ARGS__))
+
+// erf (float) op
+#define MS_ERF_F32(bit_num, ...) MS_EXPAND(MS_ERF##bit_num##_F32(__VA_ARGS__))
+
+// abs (float) op
+#define MS_ABS_F32(bit_num, ...) MS_EXPAND(MS_ABS##bit_num##_F32(__VA_ARGS__))
+
+// round (float) op
+#define MS_ROUND_F32(bit_num, ...) MS_EXPAND(MS_ROUND##bit_num##_F32(__VA_ARGS__))
+
+// ceil (float) op
+#define MS_CEIL_F32(bit_num, ...) MS_EXPAND(MS_CEIL##bit_num##_F32(__VA_ARGS__))
+
+// floor (float) op
+#define MS_FLOOR_F32(bit_num, ...) MS_EXPAND(MS_FLOOR##bit_num##_F32(__VA_ARGS__))
+
+// min (float/int) op
+#define MS_MIN_F32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_F32(__VA_ARGS__))
+#define MS_MIN_EPI32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_EPI32(__VA_ARGS__))
+#define MS_MIN_N_F32(bit_num, val, n) MS_MIN_F32(bit_num, val, MS_MOVN_F32(bit_num, n))
+#define MS_MIN_N_EPI32(bit_num, val, n) MS_MIN_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n))
+
+// max (float/int) op
+#define MS_MAX_F32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_F32(__VA_ARGS__))
+#define MS_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_EPI32(__VA_ARGS__))
+
+// get max (float/int) op
+#define MS_GET_MAX_F32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_F32(__VA_ARGS__))
+#define MS_GET_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_EPI32(__VA_ARGS__))
+
+// get sum (float) op
+#define MS_GET_SUM_F32(bit_num, ...) MS_EXPAND(MS_GET_SUM##bit_num##_F32(__VA_ARGS__))
+
+// max n (float/int) op
+#define MS_MAX_N_F32(bit_num, val, n) MS_MAX_F32(bit_num, val, MS_MOVN_F32(bit_num, n))
+#define MS_MAX_N_EPI32(bit_num, val, n) MS_MAX_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n))
+#define MS_CLAMP_F32(bit_num, val, min_val, max_val) MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, min_val), max_val)
+#define MS_CLAMP_EPI32(bit_num, val, min_val, max_val) \
+  MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, min_val), max_val)
+
+// clamp n (float/int) op
+#define MS_CLAMP_N_F32(bit_num, val, min_val, max_val) \
+  MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, MS_MOV##bit_num##_F32(min_val)), MS_MOV##bit_num##_F32(max_val))
+#define MS_CLAMP_N_EPI32(bit_num, val, min_val, max_val) \
+  MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, MS_MOV##bit_num##_EPI32(min_val)), MS_MOV##bit_num##_EPI32(max_val))
+
+// mul (float/int) op
+#define MS_MUL_F32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_F32(__VA_ARGS__))
+#define MS_MUL_EPI32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_EPI32(__VA_ARGS__))
+#define MS_MUL_N_F32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
+#define MS_MUL_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))
+
+// fma (float/int) op
+#define MS_FMADD_F32(bit_num, ...) MS_EXPAND(MS_FMADD##bit_num##_F32(__VA_ARGS__))
+#define MS_FMADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
+
+// fms (float/int) op
+#define MS_FMSUB_F32(bit_num, ...) MS_EXPAND(MS_FMSUB##bit_num##_F32(__VA_ARGS__))
+#define MS_FMSUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMSUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
+
+// square (float/int) op
+#define MS_MUL_SQUARE_F32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_F32(val, val)))
+#define MS_MUL_SQUARE_EPI32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_EPI32(val, val)))
+
+// exp (float) op
+#define MS_EXP_ST_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num(__VA_ARGS__)))
+#define MS_EXP_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num##_f32(__VA_ARGS__)))
+
+#define MS_CMPLT_F32(bit_num, ...) MS_EXPAND((MS_CMPLT##bit_num##_F32(__VA_ARGS__)))
+#define MS_CMPLE_F32(bit_num, ...) MS_EXPAND((MS_CMPLE##bit_num##_F32(__VA_ARGS__)))
+#define MS_CMPGT_F32(bit_num, ...) MS_EXPAND((MS_CMPGT##bit_num##_F32(__VA_ARGS__)))
+#define MS_BLEND_F32(bit_num, ...) MS_EXPAND((MS_BLEND##bit_num##_F32(__VA_ARGS__)))
+
+#define MS_INT16_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT16_TO_FLOAT16(__VA_ARGS__)))
+#define MS_FLOAT16_TO_INT16(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT16(__VA_ARGS__)))
+
+#define MS_INT32_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT16(__VA_ARGS__)))
+#define MS_FLOAT16_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT32(__VA_ARGS__)))
+
+#define MS_INT32_TO_FLOAT32(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT32(__VA_ARGS__)))
+#define MS_FLOAT32_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT32(__VA_ARGS__)))
+
+#define MS_INT64_TO_FLOAT32(bit_num, ...)
MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT32(__VA_ARGS__))) +#define MS_FLOAT32_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT64(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT64(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT64(__VA_ARGS__))) +#define MS_FLOAT64_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT32(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT64(__VA_ARGS__))) +#define MS_FLOAT64_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT64(__VA_ARGS__))) + +// enable avx512 +#if defined(ENABLE_AVX512) +#define MS_SIMD_RUN_AVX512(function, ...) MS_EXPAND(function(512, 16, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_AVX512(function, ...) +#endif + +// enable avx256 +#if defined(ENABLE_AVX) +#define MS_SIMD_RUN_AVX(function, ...) MS_EXPAND(function(256, 8, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_AVX(function, ...) +#endif + +// enable sse +#if defined(ENABLE_SSE) +#define MS_SIMD_RUN_SSE(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_SSE(function, ...) +#endif + +// enable neon +#if defined(ENABLE_NEON) +#define MS_SIMD_RUN_NEON(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_NEON(function, ...) +#endif + +// enable neon/sse +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) +#define MS_SIMD_RUN_SSEORNEON128(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_SSEORNEON128(function, ...) +#endif + +// scalar (c style data) +#define MS_SIMD_RUN_SCALAR(function, ...) MS_EXPAND(function(32, 1, __VA_ARGS__)) + +#define MS_SIMD_RUN(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \ + MS_SIMD_RUN_SCALAR(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_NO_SCALAR(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_X86(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSE(function, __VA_ARGS__); \ + MS_SIMD_RUN_SCALAR(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_X86_NO_SCALAR(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSE(function, __VA_ARGS__); \ + } while (0) + +#endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..f5536771f636e08d5fc25a285498948101436e04 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h @@ -0,0 +1,162 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
+#define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
+#include <arm_neon.h>
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+#if defined(ENABLE_ARM82_A32)
+static inline float16x8_t ms_vdivq_f16(float16x8_t in1, float16x8_t in2) {
+  float16x8_t dst;
+  // vrecpe gives a rough estimate of 1 / in2; two vrecps Newton-Raphson steps refine it
+  asm volatile(
+    "vrecpe.f16 q14, %3\n"
+    "vrecps.f16 q15, %3, q14\n"
+    "vmul.f16 q14, q15, q14\n"
+    "vrecps.f16 q15, %3, q14\n"
+    "vmul.f16 q14, q15, q14\n"
+    "vmul.f16 %0, %2, q14\n"
+    : "=w"(dst)
+    : "0"(dst), "w"(in1), "w"(in2)
+    : "q14", "q15");
+  return dst;
+}
+
+static inline float16x4_t ms_vdiv_f16(float16x4_t in1, float16x4_t in2) {
+  float16x4_t dst;
+  asm volatile(
+    "vrecpe.f16 d14, %3\n"
+    "vrecps.f16 d16, %3, d14\n"
+    "vmul.f16 d14, d16, d14\n"
+    "vrecps.f16 d16, %3, d14\n"
+    "vmul.f16 d14, d16, d14\n"
+    "vmul.f16 %0, %2, d14\n"
+    : "=w"(dst)
+    : "0"(dst), "w"(in1), "w"(in2)
+    : "d14", "d16");
+  return dst;
+}
+
+static inline float ms_vaddvq_f32(float32x4_t in) {
+  // vaddvq_f32 is not supported on arm82 aarch32; no single instruction reduces all the lanes
+  return in[0] + in[1] + in[2] + in[3];
+}
+
+static inline float16_t ms_vmaxvq_f16(float16x8_t in) {
+  // vmaxvq_f16 is not supported on arm82 aarch32; no single instruction reduces all the lanes
+  float16_t dst = in[0];
+  for (int i = 1; i < 8; ++i) {
+    dst = dst > in[i] ?
dst : in[i]; + } + return dst; +} + +static inline float32x4_t ms_vcvt_f32_f16(float16x4_t in) { + float32x4_t dst; + asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) { + float16x4_t dst; + asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +#define MS_CVT_F32_F16(src) ms_vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) ms_vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) ms_vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) ms_vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) +#define MS_MAXVQ_F16(src) ms_vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) ms_vaddvq_f32(src) +#else +#define MS_CVT_F32_F16(src) vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) +#define MS_MAXVQ_F16(src) vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) vaddvq_f32(src) +#endif + +#define MS_FLOAT16X8 float16x8_t +#define MS_FLOAT16X4 float16x4_t +#define MS_FLOAT16X4X4 float16x4x4_t +#define MS_FLOAT16X4X2 float16x4x2_t +#define MS_MOVQ_F16 vmovq_n_f16 +#define MS_STQ_F16(ptr, val) vst1q_f16(ptr, val) +#define MS_ST_F16 vst1_f16 +#define MS_ST2_F16 vst2_f16 +#define MS_ST4_F16 vst4_f16 +#define MS_MINQ_F16 vminq_f16 +#define MS_MAXQ_F16 vmaxq_f16 +#define MS_LDQ_F16(ptr) vld1q_f16(ptr) +#define MS_LD_F16(ptr) vld1_f16(ptr) +#define MS_ADDQ_F16 vaddq_f16 +#define MS_SUBQ_F16 vsubq_f16 +#define MS_MULQ_F16 vmulq_f16 +#define MS_FMAQ_F16 vfmaq_f16 +#define MS_MULQ_N_F16(vector, scalar) vmulq_n_f16(vector, scalar) +#define MS_CMPGTQ_F16(src1, src2) vcgtq_f16(src1, src2) + +static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { + float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src)); + float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src)); + return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high))); +} + +static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = erff(src[0]); + dst[1] = erff(src[1]); + dst[2] = erff(src[2]); + dst[3] = erff(src[3]); + dst[4] = erff(src[4]); + dst[5] = erff(src[5]); + dst[6] = erff(src[6]); + dst[7] = erff(src[7]); + return dst; +} + +static inline float16x8_t MS_SQRTFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = sqrtf(src[0]); + dst[1] = sqrtf(src[1]); + dst[2] = sqrtf(src[2]); + dst[3] = sqrtf(src[3]); + dst[4] = sqrtf(src[4]); + dst[5] = sqrtf(src[5]); + dst[6] = sqrtf(src[6]); + dst[7] = sqrtf(src[7]); + return dst; +} + +static inline float16x4_t MS_SQRTFX4_F16(float16x4_t src) { + float16x4_t dst; + dst[0] = sqrtf(src[0]); + dst[1] = sqrtf(src[1]); + dst[2] = sqrtf(src[2]); + dst[3] = sqrtf(src[3]); + return dst; +} + +static inline float32x4_t MS_VMLAL_F16(float16x4_t x, float16x4_t dy, float32x4_t sum) { + float32x4_t x_fp32 = MS_CVT_F32_F16(x); + float32x4_t dy_fp32 = MS_CVT_F32_F16(dy); + return vmlaq_f32(sum, x_fp32, dy_fp32); +} + +#endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..53333c7f3338a541f3f4d4e6333e14b55d8a99c7 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h
@@ -0,0 +1,362 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#define NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#include <math.h>
+#include <float.h>
+
+#include <arm_neon.h>
+
+#define MS_F32X4_GETI(src, i) src[i]
+#define MS128_F32_GETI(src, i) src[i]
+#define MS_FLOAT32X4 float32x4_t
+#define MS_FLOAT32X4X2 float32x4x2_t
+#define MS_FLOAT32X4X4 float32x4x4_t
+#define MS_FLOAT128_F32 float32x4_t
+#define MS_INT32X4 int32x4_t
+#define MS_INT128_EPI32 int32x4_t
+#define MS_UINT32X4 uint32x4_t
+#define MS_MASK128_TYPE MS_UINT32X4
+#define MS_LDQ_F32 vld1q_f32
+#define MS_LD128_F32 vld1q_f32
+#define MS_LDQ_EPI32 vld1q_s32
+#define MS_LD128_EPI32 vld1q_s32
+#define MS_ADDQ_F32 vaddq_f32
+#define MS_ADD128_F32 vaddq_f32
+#define MS_ADDQ_EPI32 vaddq_s32
+#define MS_ADD128_EPI32 vaddq_s32
+#define MS_MOVQ_F32 vmovq_n_f32
+#define MS_MOV128_F32 vmovq_n_f32
+#define MS_MOVQ_EPI32 vmovq_n_s32
+#define MS_MOV128_VAL0_F32 vmovq_n_f32(0.0f)
+#define MS_MOV128_EPI32 vmovq_n_s32
+#define MS_SUBQ_F32 vsubq_f32
+#define MS_SUB128_F32 vsubq_f32
+#define MS_SUB128_EPI32 vsubq_s32
+#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
+#define MS_STQ_F32 vst1q_f32
+#define MS_ST128_F32 vst1q_f32
+#define MS_STQ_EPI32 vst1q_s32
+#define MS_ST128_EPI32 vst1q_s32
+#define MS_MAXQ_F32 vmaxq_f32
+#define MS_MAXQ_EPI32 vmaxq_s32
+#define MS_MAX128_F32 vmaxq_f32
+#define MS_MAX128_EPI32 vmaxq_s32
+#define MS_MINQ_F32 vminq_f32
+#define MS_MINQ_EPI32 vminq_s32
+#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
+#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
+#define MS_MIN128_F32 vminq_f32
+#define MS_MIN128_EPI32 vminq_s32
+#define MS_MUL128_F32(src1, src2) vmulq_f32(src1, src2)
+#define MS_MUL128_EPI32(src1, src2) vmulq_s32(src1, src2)
+#define MS_FMADD128_F32(src1, src2, src3) vmlaq_f32(src3, src1, src2)
+#define MS_FSMUL128_F32(src1, src2, src3) vmlsq_f32(src1, src2, src3)
+#define MS_FMSUB128_EPI32(src1, src2, src3) vmlsq_s32(src3, src1, src2)
+#ifdef ENABLE_ARM64
+#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
+#define MS_DIV128_F32(src1, src2) vdivq_f32(src1, src2)
+#else
+static inline float32x4_t vrecp(float32x4_t v) {
+  float32x4_t r = vrecpeq_f32(v);
+  r = vmulq_f32(vrecpsq_f32(v, r), r);
+  r = vmulq_f32(vrecpsq_f32(v, r), r);
+  return r;
+}
+#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
+#define MS_DIV128_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
+#endif
+#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
+#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
+#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2)
+#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
+#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
+#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
+#define MS_CMPLEQ_F32(src1, src2) vcleq_f32(src1, src2)
+#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2)
+#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
+#define MS_CMPLE128_F32(src1, src2) vcleq_f32(src1, src2)
+#define MS_CMPLT128_F32(src1, src2) vcltq_f32(src1, src2)
+#define MS_CMPGT128_F32(src1, src2) vcgtq_f32(src1, src2)
+#define MS_CMPGT128_EPI32(src1, src2) vcgtq_s32(src1, src2)
+// Note: vbslq takes the mask first, so its argument order is the opposite of x86 _mm_blendv_ps
+#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
+#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
+#define MS_BLEND128_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
+#define MS_BLEND128_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
+#define MS_CAST128_F32_S32(src) vreinterpretq_f32_s32(src)
+#define MS_AND128_MASK(src1, src2) vandq_u32(src1, src2)
+#define MS_AND128_F32(src1, src2) \
+  vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(src1), vreinterpretq_u32_f32(src2)))
+#define MS_OR128_F32(src1, src2) \
+  vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(src1), vreinterpretq_u32_f32(src2)))
+#define MS_CAST128_U32_F32(src) vreinterpretq_u32_f32(src)
+#define MS_CAST128_F32_U32(src) vreinterpretq_f32_u32(src)
+#define MS_OR128_MASK(src1, src2) vorrq_u32(src1, src2)
+
+#ifdef ENABLE_ARM64
+#define MS_GET_MAX128_F32 vmaxvq_f32
+static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) { return vaddvq_f32(src); }
+#else
+static inline float MS_GET_MAX128_F32(MS_FLOAT32X4 src) {
+  float result = MS_F32X4_GETI(src, 0);
+  for (int i = 1; i < 4; i++) {  // neon block num : 4
+    result = fmaxf(result, MS_F32X4_GETI(src, i));
+  }
+  return result;
+}
+
+static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) {
+  float result = MS_F32X4_GETI(src, 0);
+  for (int i = 1; i < 4; i++) {  // neon block num : 4
+    result = result + MS_F32X4_GETI(src, i);
+  }
+  return result;
+}
+#endif
+
+static inline MS_FLOAT32X4 MS_AND128_MASK_F32(MS_UINT32X4 src1, MS_FLOAT32X4 src2) {
+  MS_FLOAT32X4 result;
+  result[0] = (src1[0] == 0) ? 0.0f : src2[0];
+  result[1] = (src1[1] == 0) ? 0.0f : src2[1];
+  result[2] = (src1[2] == 0) ? 0.0f : src2[2];
+  result[3] = (src1[3] == 0) ?
0.0f : src2[3]; + return result; +} + +static inline int32x4_t MS_DIV128_EPI32(int32x4_t src1, int32x4_t src2) { + int32x4_t result; + result[0] = src1[0] / src2[0]; // C0 : 0 + result[1] = src1[1] / src2[1]; // C1 : 1 + result[2] = src1[2] / src2[2]; // C2 : 2 + result[3] = src1[3] / src2[3]; // C3 : 3 + return result; +} + +#define MS128_INT32_TO_FLOAT32(src) vcvtq_f32_s32(src) +#define MS128_FLOAT32_TO_INT32(src) vcvtq_s32_f32(src) + +static inline MS_FLOAT32X4 MS_POW128_F32(MS_FLOAT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = powf(MS_F32X4_GETI(src1, 0), MS_F32X4_GETI(src2, 0)); + MS_F32X4_GETI(dst, 1) = powf(MS_F32X4_GETI(src1, 1), MS_F32X4_GETI(src2, 1)); + MS_F32X4_GETI(dst, 2) = powf(MS_F32X4_GETI(src1, 2), MS_F32X4_GETI(src2, 2)); + MS_F32X4_GETI(dst, 3) = powf(MS_F32X4_GETI(src1, 3), MS_F32X4_GETI(src2, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_ABS128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = fabsf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = fabsf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = fabsf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = fabsf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS128_LOG_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = logf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = logf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = logf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = logf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_SQRT128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} +#define MS_RSQRT128_F32 vrsqrteq_f32 + +#define LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + +static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 
1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + static MS_FLOAT32X4 negative_flag = {-0.0f, -0.0f, -0.0f, -0.0f}; + + MS_INT32X4 integer = + MS_CVTQPS_EPI32(MS_FMADD128_F32(input, param[6], MS_OR128_F32(MS_AND128_F32(input, negative_flag), param[4]))); + MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); + MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); + tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); + MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); + return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); +} + +static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_STQ_F32(dst, VexpFp32(input)); +} + +static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + return VexpFp32(input); +} + +static inline MS_FLOAT32X4 simd_hexp128_f32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = exp(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = exp(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = exp(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = exp(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X4 up_limit = {5.0f, 5.0f, 5.0f, 5.0f}; + static const MS_FLOAT32X4 down_limit = {-5.0f, -5.0f, -5.0f, -5.0f}; + + MS_UINT32X4 up_mask = MS_CMPGTQ_F32(src, up_limit); + MS_UINT32X4 down_mask = MS_CMPGTQ_F32(down_limit, src); + + MS_FLOAT32X4 square = MS_MULQ_F32(src, src); + MS_FLOAT32X4 a = MS_MULQ_F32( + MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), src); + MS_FLOAT32X4 b = MS_ADDQ_F32( + MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square), + data2); + + MS_FLOAT32X4 tanh_value = MS_DIVQ_F32(a, b); + MS_FLOAT32X4 res = MS_BLENDQ_F32(tanh_value, pos, up_mask); + res = MS_BLENDQ_F32(res, neg, down_mask); + return res; +} + +static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { + MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); + MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); + MS_FLOAT128_F32 
sign = MS_DIV128_F32(abs_src, src_tmp); + return sign; +} + +static inline MS_FLOAT128_F32 SIMD_SIGNABS128_F32(MS_FLOAT128_F32 src, MS_FLOAT128_F32 abs_src) { + MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); + return MS_DIV128_F32(abs_src, src_tmp); +} + +#define MS_TANH128_F32 MS_TANHX4_F32 + +static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3)); + return dst; +} + +#define MS_FMADD128X8_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \ + dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \ + dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \ + dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \ + dst##8 = MS_MLAQ_F32(src##8, weight, dst##8); + +#define MS_LOAD128X4_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); + +#define MS_FMADD128X4_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); + +#define MS_LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define MS_SET_ZERO128X8_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f); + +#define MS_SET_ZERO128X4_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); +#endif // NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..6eb07e256818385c221f939c10a4346ff6675ffc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h @@ -0,0 +1,403 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#define NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#include <math.h>
+
+#ifdef _MSC_VER
+#include <immintrin.h>
+#define MS_F32X4_GETI(src, i) src.m128_f32[i]
+#define MS128_F32_GETI(src, i) src.m128_f32[i]
+#else
+#include <x86intrin.h>
+#define MS_F32X4_GETI(src, i) src[i]
+#define MS128_F32_GETI(src, i) src[i]
+#endif
+
+#define PI 3.1415926f
+#define LN2 0.693147f
+
+#define MS_FLOAT32X4 __m128
+#define MS_FLOAT128_F32 __m128
+#define MS_INT32X4 __m128i
+#define MS_INT128_EPI32 __m128i
+#define MS_MASK128_TYPE MS_FLOAT32X4
+#define MS_LDQ_F32 _mm_loadu_ps
+#define MS_LD128_F32 _mm_loadu_ps
+#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
+#define MS_LD128_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
+#define MS_ADDQ_F32 _mm_add_ps
+#define MS_ADD128_F32 _mm_add_ps
+#define MS_ADDQ_EPI32 _mm_add_epi32
+#define MS_ADD128_EPI32 _mm_add_epi32
+#define MS_MOVQ_F32 _mm_set1_ps
+#define MS_MOV128_F32 _mm_set1_ps
+#define MS_MOVQ_EPI32 _mm_set1_epi32
+#define MS_MOV128_EPI32 _mm_set1_epi32
+#define MS_MOV128_VAL0_F32 _mm_setzero_ps()
+#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
+#define MS_STQ_F32 _mm_storeu_ps
+#define MS_ST128_F32 _mm_storeu_ps
+#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
+#define MS_ST128_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
+#define MS_SUBQ_F32 _mm_sub_ps
+#define MS_SUB128_F32 _mm_sub_ps
+#define MS_SUB128_EPI32 _mm_sub_epi32
+#define MS_MAXQ_F32 _mm_max_ps
+#define MS_MAXQ_EPI32 _mm_max_epi32
+#define MS_MAX128_F32 _mm_max_ps
+#define MS_MAX128_EPI32 _mm_max_epi32
+#define MS_MINQ_F32 _mm_min_ps
+#define MS_MINQ_EPI32 _mm_min_epi32
+#define MS_SQRT128_F32 _mm_sqrt_ps
+#define MS_RSQRT128_F32 _mm_rsqrt_ps
+// _mm_sin_ps and _mm_erf_ps come from SVML and are not available in every toolchain
+#define MS_SIN128_F32 _mm_sin_ps
+#define MS_ERF128_F32 _mm_erf_ps
+#define MS_ROUND128_F32(src) _mm_round_ps(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
+#define MS_FLOOR128_F32 _mm_floor_ps
+#define MS_CEIL128_F32 _mm_ceil_ps
+#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
+#define MS_MULQ_EPI32(src1, src2) _mm_mullo_epi32(src1, src2)
+#define MS_MIN128_F32 _mm_min_ps
+#define MS_MIN128_EPI32 _mm_min_epi32
+#define MS_MUL128_F32(src1, src2) _mm_mul_ps(src1, src2)
+#define MS_MUL128_EPI32(src1, src2) _mm_mullo_epi32(src1, src2)
+#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
+#define MS_DIV128_F32(src1, src2) _mm_div_ps(src1, src2)
+#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
+#define MS_MULQ_N_EPI32(src1, src2) _mm_mullo_epi32(src1, _mm_set1_epi32(src2))
+#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
+#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
+#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src)  // truncate float to int
+#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
+#define MS_CMPLEQ_F32(src1, src2) _mm_cmple_ps(src1, src2)
+#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
+#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
+#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
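+// blendv semantics: _mm_blendv_ps/_mm_blendv_epi8 return src2 where the mask
+// lane is set and src1 otherwise; the NEON vbslq wrappers in
+// ms_simd_neon_instructions.h take the mask as the first argument instead.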
+#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
+#define MS_CMPLT128_F32(src1, src2) _mm_cmplt_ps(src1, src2)
+#define MS_CMPLE128_F32(src1, src2) _mm_cmple_ps(src1, src2)
+#define MS_CMPGT128_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
+#define MS_CMPEQ128_F32(src1, src2) _mm_cmpeq_ps(src1, src2)
+#define MS_CMPUNORD128_F32(src1, src2) _mm_cmpunord_ps(src1, src2)
+#define MS_CMPGT128_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
+#define MS_BLEND128_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
+#define MS_BLEND128_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
+#define MS_CAST128_F32_S32(src) _mm_castsi128_ps(src)
+#define MS_DIV128_EPI32(src1, src2) _mm_cvttps_epi32(MS_DIV128_F32(_mm_cvtepi32_ps(src1), _mm_cvtepi32_ps(src2)))
+#define MS_AND128_MASK(src1, src2) _mm_and_ps(src1, src2)
+#define MS_OR128_F32(src1, src2) _mm_or_ps(src1, src2)
+#define MS_AND128_MASK_F32(src1, src2) _mm_and_ps(src1, src2)
+#define MS_AND128_F32(src1, src2) _mm_and_ps(src1, src2)
+
+#define MS128_ANDNOT_F32(src1, src2) _mm_andnot_ps(src1, src2)
+#define MS128_SRLI_EPI32(src1, src2) _mm_srli_epi32(src1, src2)
+#define MS128_AND_EPI32(src1, src2) _mm_and_si128(src1, src2)
+#define MS128_CASTPS_EPI32(src) _mm_castps_si128(src)
+#define MS_CVT128EPI32_PS(src) _mm_cvtepi32_ps(src)
+
+static inline MS_FLOAT32X4 MS_POW128_F32(MS_FLOAT32X4 src1, MS_FLOAT32X4 src2) {
+  MS_FLOAT32X4 dst;
+  MS_F32X4_GETI(dst, 0) = powf(MS_F32X4_GETI(src1, 0), MS_F32X4_GETI(src2, 0));
+  MS_F32X4_GETI(dst, 1) = powf(MS_F32X4_GETI(src1, 1), MS_F32X4_GETI(src2, 1));
+  MS_F32X4_GETI(dst, 2) = powf(MS_F32X4_GETI(src1, 2), MS_F32X4_GETI(src2, 2));
+  MS_F32X4_GETI(dst, 3) = powf(MS_F32X4_GETI(src1, 3), MS_F32X4_GETI(src2, 3));
+  return dst;
+}
+
+#ifdef ENABLE_AVX  // SSE-only builds have no FMA instructions, so the #else branch emulates them
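+// In both branches below: FMADD(a, b, c) = a * b + c, FMSUB(a, b, c) = a * b - c,
+// and FSMUL(a, b, c) = a - b * c.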
+#define MS_FMADD128_F32(src1, src2, src3) _mm_fmadd_ps(src1, src2, src3) +#define MS_FMSUB128_F32(src1, src2, src3) _mm_fmsub_ps(src1, src2, src3) +#define MS_FSMUL128_F32(src1, src2, src3) _mm_fnmadd_ps(src3, src2, src1) +#else +#define MS_FMADD128_F32(src1, src2, src3) _mm_add_ps(_mm_mul_ps(src1, src2), src3) +#define MS_FMSUB128_F32(src1, src2, src3) _mm_sub_ps(_mm_mul_ps(src1, src2), src3) +#define MS_FSMUL128_F32(src1, src2, src3) _mm_sub_ps(src1, _mm_mul_ps(src2, src3)) +#endif + +#define MS128_INT16_TO_FLOAT16(src) _mm_cvtepi16_ph(src) +#define MS128_FLOAT16_TO_INT16(src) _mm_cvttph_epi16(src) + +#define MS128_INT32_TO_FLOAT16(src) _mm_cvtepi32_ph(src) +#define MS128_FLOAT16_TO_INT32(src) _mm_cvttph_epi32(src) + +#define MS128_INT32_TO_FLOAT32(src) _mm_cvtepi32_ps(src) +#define MS128_FLOAT32_TO_INT32(src) _mm_cvttps_epi32(src) + +#define MS128_INT64_TO_FLOAT32(src) _mm_cvtepi64_ps(src) +#define MS128_FLOAT32_TO_INT64(src) _mm_cvttps_epi64(src) + +#define MS128_INT64_TO_FLOAT16(src) _mm_cvtepi64_ph(src) +#define MS128_FLOAT16_TO_INT64(src) _mm_cvttph_epi64(src) + +#define MS128_INT32_TO_FLOAT64(src) _mm_cvtepi32_pd(src) +#define MS128_FLOAT64_TO_INT32(src) _mm_cvttpd_epi32(src) + +#define MS128_INT64_TO_FLOAT64(src) _mm_cvtepi64_pd(src) +#define MS128_FLOAT64_TO_INT64(src) _mm_cvttpd_epi64(src) + +#define MS128_INT16_TO_INT32(src) _mm128_cvtepi16_epi32(src) +#define MS128_INT16_TO_INT64(src) _mm128_cvtepi16_epi64(src) +#define MS128_INT32_TO_INT16(src) _mm128_cvtepi32_epi16(src) +#define MS128_INT32_TO_INT64(src) _mm128_cvtepi32_epi64(src) +#define MS128_INT64_TO_INT16(src) _mm128_cvtepi64_epi16(src) +#define MS128_INT64_TO_INT32(src) _mm128_cvtepi64_epi32(src) + +static inline MS_FLOAT32X4 MS_ABS128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = fabsf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = fabsf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = fabsf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = fabsf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { + MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); + MS_FLOAT128_F32 sign = MS_DIV128_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS128_F32(src, abs_src) MS_DIV128_F32(abs_src, src) + +static inline MS_FLOAT32X4 MS_COS128_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 pi = {PI, PI, PI, PI}; + static const MS_FLOAT32X4 pi2_neg = {-2 * PI, -2 * PI, -2 * PI, -2 * PI}; + static const MS_FLOAT32X4 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT128_F32 src_abs = MS_ABS128_F32(src); + MS_FLOAT128_F32 src_cycle = + MS_ADD128_F32(MS_MUL128_F32(MS_FLOOR128_F32(MS_MUL128_F32(MS_ADD128_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + static const MS_FLOAT128_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT128_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT128_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT128_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT128_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = MS_MUL128_F32(src_cycle, src_cycle); + + MS_FLOAT32X4 tmp = + MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(MS_MUL128_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X4 tmp1 = 
MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(tmp, neg), square), data2); + MS_FLOAT128_F32 res = MS_ADD128_F32( + MS_MUL128_F32( + MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X4 MS128_LOG_F32(MS_FLOAT32X4 src) { + const MS_INT128_EPI32 gFloatExpMask = MS_MOV128_EPI32(0xffULL << 23); + const MS_INT128_EPI32 gFloatExp0 = MS_MOV128_EPI32(127ULL << 23); + const MS_INT128_EPI32 gExpNormalizer = MS_MOV128_EPI32(127); + static const MS_FLOAT128_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT128_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT128_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT128_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT128_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT128_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT128_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X4 ln2 = {LN2, LN2, LN2, LN2}; + + const MS_INT128_EPI32 exps32 = MS128_SRLI_EPI32(MS128_AND_EPI32(gFloatExpMask, MS128_CASTPS_EPI32(src)), 23); + const MS_INT128_EPI32 normExps = MS_SUB128_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X4 expsPD = MS_CVT128EPI32_PS(normExps); + const MS_FLOAT32X4 y = + MS_OR128_F32(MS_CAST128_F32_S32(gFloatExp0), MS128_ANDNOT_F32(MS_CAST128_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X4 div = MS_DIV128_F32(MS_ADD128_F32(y, neg), MS_ADD128_F32(y, pos)); + MS_FLOAT32X4 square = MS_MUL128_F32(div, div); + + MS_FLOAT32X4 tmp = MS_ADD128_F32( + MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(square, MS_ADD128_F32(MS_MUL128_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X4 tmp1 = MS_MUL128_F32(square, MS_ADD128_F32(MS_MUL128_F32(square, tmp), data4)); + MS_FLOAT32X4 res = + MS_ADD128_F32(MS_MUL128_F32(ln2, expsPD), MS_MUL128_F32(MS_MUL128_F32(div, MS_ADD128_F32(tmp1, data5)), data6)); + MS_FLOAT32X4 mask = MS_CMPEQ128_F32(src, MS_MOV128_F32(0.0f)); + res = MS_BLEND128_F32(res, MS_MOV128_F32(-INFINITY), mask); + mask = MS_CMPEQ128_F32(src, MS_MOV128_F32(INFINITY)); + res = MS_BLEND128_F32(res, MS_MOV128_F32(INFINITY), mask); + mask = MS_OR128_F32(MS_CMPLT128_F32(src, MS_MOV128_F32(0.0f)), MS_CMPUNORD128_F32(src, MS_MOV128_F32(0.0f))); + res = MS_BLEND128_F32(res, MS_MOV128_F32(NAN), mask); + return res; +} + +static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline float MS_GET_MAX128_F32(__m128 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // sse block num : 4 + result = fmaxf(result, MS_F32X4_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM128_F32(__m128 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // sse block num : 4 + result = result + MS_F32X4_GETI(src, i); + } + return result; +} + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + 
MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + +static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + MS_INT32X4 integer = MS_CVTQPS_EPI32(MS_FLOOR128_F32(MS_FMADD128_F32(input, param[6], param[4]))); + MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); + MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); + tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); + MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); + return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); +} + +static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_STQ_F32(dst, VexpFp32(input)); +} + +static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + return VexpFp32(input); +} + +static inline MS_FLOAT32X4 simd_hexp128_f32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = exp(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = exp(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = exp(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = exp(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = MS_MULQ_F32(src, src); + MS_FLOAT32X4 a = + MS_MUL128_F32(MS_FMADD128_F32(MS_FMADD128_F32(MS_ADD128_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X4 b = + MS_FMADD128_F32(MS_FMADD128_F32(MS_FMADD128_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X4 res = MS_DIVQ_F32(a, b); + MS_FLOAT32X4 up_limit = MS_MOV128_F32(5.0f); + MS_FLOAT32X4 down_limit = MS_MOV128_F32(-5.0f); + MS_FLOAT32X4 up_mask = 
MS_CMPGT128_F32(src, up_limit); + MS_FLOAT32X4 down_mask = MS_CMPLT128_F32(src, down_limit); + res = MS_BLEND128_F32(res, pos, up_mask); + res = MS_BLEND128_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH128_F32 MS_TANHX4_F32 + +static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3)); + return dst; +} + +#define MS_FMADD128X8_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \ + dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \ + dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \ + dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \ + dst##8 = MS_MLAQ_F32(src##8, weight, dst##8); + +#define MS_LOAD128X4_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); + +#define MS_FMADD128X4_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); + +#define MS_LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define MS_SET_ZERO128X8_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f); + +#define MS_SET_ZERO128X4_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); +#endif // NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c new file mode 100644 index 0000000000000000000000000000000000000000..21f356a23e948a42978f53d326aefdb488bc9529 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#include "nnacl_c/fp32/conv_depthwise_fp32.h" + +#define INPUT_SIZE 25 + +void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6) { + input_stride /= sizeof(float *); + size_t c8 = UP_DIV(channels, C8NUM) * C8NUM; + size_t c8_mod = channels % C8NUM; + float *in[INPUT_SIZE]; + for (int i = 0; i < output_width; ++i) { + for (int k = 0; k < INPUT_SIZE; k++) { + in[k] = input[k]; + } + input += input_stride; + size_t c = c8; + const float *w = weights; + const float *bias1 = bias; + for (; c >= C8NUM; c -= C8NUM) { + __m256 out1 = _mm256_loadu_ps(bias1); + bias1 += 8; + for (int k = 0; k < INPUT_SIZE; k += 5) { + __m256 in1 = _mm256_loadu_ps(in[k]); + __m256 w1 = _mm256_loadu_ps(w); + __m256 in2 = _mm256_loadu_ps(in[k + 1]); + __m256 w2 = _mm256_loadu_ps(w + 8); + out1 = _mm256_fmadd_ps(in1, w1, out1); + __m256 in3 = _mm256_loadu_ps(in[k + 2]); + __m256 w3 = _mm256_loadu_ps(w + 16); + out1 = _mm256_fmadd_ps(in2, w2, out1); + __m256 in4 = _mm256_loadu_ps(in[k + 3]); + __m256 w4 = _mm256_loadu_ps(w + 24); + out1 = _mm256_fmadd_ps(in3, w3, out1); + __m256 in5 = _mm256_loadu_ps(in[k + 4]); + __m256 w5 = _mm256_loadu_ps(w + 32); + out1 = _mm256_fmadd_ps(in4, w4, out1); + w += 40; + in[k] += C8NUM; + in[k + 1] += C8NUM; + in[k + 2] += C8NUM; + in[k + 3] += C8NUM; + in[k + 4] += C8NUM; + out1 = _mm256_fmadd_ps(in5, w5, out1); + } + if (relu6 != 0) { + __m256 relu6_data = _mm256_set1_ps(6.0); + out1 = _mm256_min_ps(out1, relu6_data); + } + if (relu != 0 || relu6 != 0) { + __m256 zero = _mm256_setzero_ps(); + out1 = _mm256_max_ps(out1, zero); + } + if (c > C8NUM || c8_mod == 0) { + _mm256_storeu_ps(output, out1); + output += C8NUM; + } else { + __m128 tmp; + switch (c8_mod) { + case 1: + _mm_store_ss(output, _mm256_castps256_ps128(out1)); + break; + case 2: + _mm_storel_pi((__m64 *)output, _mm256_castps256_ps128(out1)); + break; + case 3: + tmp = _mm256_castps256_ps128(out1); + _mm_storel_pi((__m64 *)output, tmp); + tmp = _mm_unpackhi_ps(tmp, tmp); + _mm_store_ss(output + 2, tmp); + break; + case 4: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + break; + case 5: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + _mm_store_ss(output + 4, _mm256_extractf128_ps(out1, 1)); + break; + case 6: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + _mm_storel_pi((__m64 *)(output + 4), _mm256_extractf128_ps(out1, 1)); + break; + case 7: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + tmp = _mm256_extractf128_ps(out1, 1); + _mm_storel_pi((__m64 *)(output + 4), tmp); + tmp = _mm_unpackhi_ps(tmp, tmp); + _mm_store_ss(output + 6, tmp); + break; + default: + _mm256_storeu_ps(output, out1); + break; + } + output += c8_mod; + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c new file mode 100644 index 
0000000000000000000000000000000000000000..09b7e956fa22f25ecf76ce9588f7f55d06a88324 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + size_t out_c16 = DOWN_DIV(output_channel, C16NUM) * C16NUM; + size_t out_c8 = DOWN_DIV(output_channel, C8NUM) * C8NUM; + size_t out_c4 = DOWN_DIV(output_channel, C4NUM) * C4NUM; + for (int i = 0; i < num_pixels; i++) { + const float *weight_tmp = weight_ptr; + const float *input_tmp = input_ptr; + size_t out_c = 0; + for (; out_c < out_c16; out_c += C16NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 dst2 = _mm_loadu_ps(output_ptr + 4); + __m128 dst3 = _mm_loadu_ps(output_ptr + 8); + __m128 dst4 = _mm_loadu_ps(output_ptr + 12); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 w2 = _mm_loadu_ps(weight_tmp + 4); + __m128 w3 = _mm_loadu_ps(weight_tmp + 8); + __m128 w4 = _mm_loadu_ps(weight_tmp + 12); + __m128 in1 = _mm_loadu_ps(input_tmp); + __m128 in2 = _mm_loadu_ps(input_tmp + 4); + __m128 in3 = _mm_loadu_ps(input_tmp + 8); + __m128 in4 = _mm_loadu_ps(input_tmp + 12); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + dst2 = MS_MLAQ_F32(dst2, w2, in2); + dst3 = MS_MLAQ_F32(dst3, w3, in3); + dst4 = MS_MLAQ_F32(dst4, w4, in4); + _mm_storeu_ps(output_ptr, dst1); + _mm_storeu_ps(output_ptr + 4, dst2); + _mm_storeu_ps(output_ptr + 8, dst3); + _mm_storeu_ps(output_ptr + 12, dst4); + output_ptr += 16; + input_tmp += 16; + weight_tmp += 16; + } + for (; out_c < out_c8; out_c += C8NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 dst2 = _mm_loadu_ps(output_ptr + 4); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 w2 = _mm_loadu_ps(weight_tmp + 4); + __m128 in1 = _mm_loadu_ps(input_tmp); + __m128 in2 = _mm_loadu_ps(input_tmp + 4); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + dst2 = MS_MLAQ_F32(dst2, w2, in2); + _mm_storeu_ps(output_ptr, dst1); + _mm_storeu_ps(output_ptr + 4, dst2); + output_ptr += 8; + input_tmp += 8; + weight_tmp += 8; + } + for (; out_c < out_c4; out_c += C4NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 in1 = _mm_loadu_ps(input_tmp); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + _mm_storeu_ps(output_ptr, dst1); + output_ptr += 4; + input_tmp += 4; + weight_tmp += 4; + } + for (; out_c < output_channel; out_c++) { + *output_ptr++ += weight_ptr[out_c] * input_ptr[out_c]; + } + input_ptr += input_step; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c new file mode 100644 index 
0000000000000000000000000000000000000000..82832584643edd6f27b99168e07d3276223496f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c @@ -0,0 +1,327 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_SSE +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" + +#ifndef ENABLE_AVX +void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + kernel_w_step /= sizeof(float); + + const float *src_kh = src; + const float *weight_kh = weight; + __m128 dst_ma = _mm_setzero_ps(); + + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + + int c1 = 0; + int c4 = DOWN_DIV(width, C4NUM) * C4NUM; + int c2 = DOWN_DIV(width, C2NUM) * C2NUM; + // c4 loop + for (; c1 < c4; c1 += C4NUM) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step); + __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step); + __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step); + + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM); + __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM); + + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2); + __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3); + __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4); + + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma2); + dst_ma = _mm_add_ps(dst_ma, mul_ma3); + dst_ma = _mm_add_ps(dst_ma, mul_ma4); + + src_kw += in_kw_step * 4; + weight_kw += C4NUM * 4; + } + + // c2 loop + for (; c1 < c2; c1 += C2NUM) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step); + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2); + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma2); + + src_kw += in_kw_step * 2; + weight_kw += C4NUM * 2; + } + + // remaining + for (; c1 < width; ++c1) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + + src_kw += in_kw_step; + weight_kw += C4NUM; + } + + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } + + __m128 bias_ma = _mm_loadu_ps(bias); + dst_ma = _mm_add_ps(dst_ma, bias_ma); + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + dst_ma = _mm_max_ps(zero_ma, 
dst_ma); + if (relu6) { + __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + dst_ma = _mm_min_ps(const_ma, dst_ma); + } + } + _mm_storeu_ps(dst, dst_ma); +} +#endif + +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) { + out_h_step /= sizeof(float); + block_channel /= sizeof(float); + in_sh_step /= sizeof(float); + in_sw_step /= sizeof(float); + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + int c4 = DOWN_DIV(width, C4NUM) * C4NUM; + int c2 = DOWN_DIV(width, C2NUM) * C2NUM; + int c1 = 0; + // c4 loop + for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + __m128 dst_w_ma2 = _mm_setzero_ps(); + __m128 dst_w_ma3 = _mm_setzero_ps(); + __m128 dst_w_ma4 = _mm_setzero_ps(); + + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + + __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + + __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step); + __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3); + + __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step); + __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); + } // kernel_w loop + } // kernel_h loop + + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma); + + ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6); + + _mm_storeu_ps(dst_w, dst_w_ma1); + _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); + _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3); + _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4); + } // dst_width loop + + // c2 loop + for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + __m128 dst_w_ma2 = _mm_setzero_ps(); + + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = 
_mm_add_ps(dst_w_ma1, tmp_ma1); + + __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + } // kernel_w loop + } // kernel_h loop + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); + + ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6); + + _mm_storeu_ps(dst_w, dst_w_ma1); + _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); + } + + // remaining + for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + } // kernel_w loop + } // kernel_h loop + + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + ActBlock1(&dst_w_ma1, relu, relu6); + _mm_storeu_ps(dst_w, dst_w_ma1); + } + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} + +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, + size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, + size_t in_kh_step, size_t in_kw_step) { + out_h_step /= sizeof(float); + block_channel /= sizeof(float); + in_sh_step /= sizeof(float); + in_sw_step /= sizeof(float); + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + __m128 src_w_ma = _mm_loadu_ps(src_w); + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + + int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM; + int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM; + int c1 = 0; + // c4 loop + for (; c1 < c4; c1 += C4NUM) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2); + + __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step); + __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM); + __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3); + _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3); + + __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step); + __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM); + __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); + _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4); + + 
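+        // four kernel-width taps accumulated and stored above; step the output and weight cursors past them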
dst_kw += 4 * in_kw_step; + weight_kw += 4 * C4NUM; + } + // c2 loop + for (; c1 < c2; c1 += C2NUM) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2); + + dst_kw += 2 * in_kw_step; + weight_kw += 2 * C4NUM; + } + // remaining + for (; c1 < kernel_w; ++c1) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c new file mode 100644 index 0000000000000000000000000000000000000000..17103cdb0eec6c5e2418c49a175c1c69db407a6e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c @@ -0,0 +1,243 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_SSE +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" +#include "nnacl_c/base/minimal_filtering_generator.h" + +void MatrixMultiplyWinograd(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + const float *src1 = matrix_a; + int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; + int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; + for (int i = 0; i < m; ++i) { + const float *src1_n = src1; + const float *src2_n = matrix_b; + for (int j = 0; j < n; ++j) { + const float *src1_j = src1_n; + int y = 0; + // 16 channel + for (; y < c16; y += C16NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + __m128 ma3 = _mm_loadu_ps(src1_j + 8); + __m128 ma4 = _mm_loadu_ps(src1_j + 12); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + __m128 tmp3 = _mm_mul_ps(ma3, mb); + __m128 tmp4 = _mm_mul_ps(ma4, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + _mm_storeu_ps(matrix_c + 8, dst3); + _mm_storeu_ps(matrix_c + 12, dst4); + src1_j -= in_channel * k; + src1_j += C16NUM; + matrix_c += C16NUM; + } + // 8 channel + for (; y < c8; y += C8NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + src1_j -= in_channel * k; + src1_j += C8NUM; + matrix_c += C8NUM; + } + // remaining channels + for (; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matrix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + *matrix_c++ = tmp; + } + src2_n += 1; + } + src1 += k * in_channel; + } +} + +void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode) { + int C8Steps = row * C8NUM, WinoSteps1 = stride * col, WinoSteps2 = stride * C8NUM; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b, *bias_d = bias; + float *dst = NULL; + for (int cc = col; cc > 0; cc -= C8NUM) { + if (write_mode != 0) { // writec8 + dst = c; + } + const float *srca_d = a; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(), dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(), dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = depth; d > 0; --d) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 
= _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + + if (bias != NULL) { + DoBiasBlock8(bias_d, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); + bias_d += C8NUM; + } + + ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); + + if (write_mode == OutType_TileC8) { // WriteWino + c = dst + WinoSteps2; + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); + } else if (write_mode == OutType_C8) { // WriteC8 + _mm_storeu_ps(c, dst1), _mm_storeu_ps(c + 4, dst2); + _mm_storeu_ps(c + 8, dst3), _mm_storeu_ps(c + 12, dst4); + _mm_storeu_ps(c + 16, dst5), _mm_storeu_ps(c + 20, dst6); + _mm_storeu_ps(c + 24, dst7), _mm_storeu_ps(c + 28, dst8); + c += C8Steps; + } else { + switch (cc) { + case 1: // write1 + c = dst + 1; + WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 1, r); + break; + case 2: // write2 + c = dst + 2; + WriteCol2Opt(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); + break; + case 3: // write3 + c = dst + 3; + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst1); + WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 3, r); + break; + case 4: // write4 + c = dst + 4; + WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 4, r); + break; + case 5: // write5 + c = dst + 5; + WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 5, r); + break; + case 6: // write6 + c = dst + 6; + WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 6, r); + break; + case 7: // write7 + c = dst + 7; + WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 7, r); + break; + default: // write8 + c = dst + C8NUM; + WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 8, r); + break; + } + } + if (cc <= C8NUM) break; // write end + } + a += C4NUM * depth; + if (write_mode == OutType_C8) c += 32; + if (write_mode == OutType_TileC8) c = dst + WinoSteps2; + if (write_mode == OutType_Nhwc) c = dst - col; + if (r <= C4NUM) break; + } +} + +void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) { + for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { + const float *srca_d = a; + float *dst = c; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(); + __m128 dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = 0; d < depth; d++) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = 
_mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); + dst += 32; + c = dst; + } + b += depth * C8NUM; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c new file mode 100644 index 0000000000000000000000000000000000000000..0526761b72203772b40df8841efd07ff7c0b28f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + stride /= sizeof(float); + for (int loop_c8 = 0; loop_c8 != oc8div; loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m128 bias1 = _mm_setzero_ps(), bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + 4); + bias += 8; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM, src += 32) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + __m128 src5 = _mm_loadu_ps(src + 16); + __m128 src6 = _mm_loadu_ps(src + 20); + __m128 src7 = _mm_loadu_ps(src + 24); + __m128 src8 = _mm_loadu_ps(src + 28); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias2); + src5 = _mm_add_ps(src5, bias1); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias1); + src8 = _mm_add_ps(src8, bias2); + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm_storeu_ps(dst_c8, src1); + _mm_storeu_ps(dst_c8 + 4, src2); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src3); + _mm_storeu_ps(dst_c8 + 4, src4); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src5); + _mm_storeu_ps(dst_c8 + 4, src6); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src7); + _mm_storeu_ps(dst_c8 + 4, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c8 += stride) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + + _mm_storeu_ps(dst_c8, src1); + _mm_storeu_ps(dst_c8 + 4, src2); + } + } + + if (oc8mod == 0) return; + + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + 4); + bias += 8; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c1 += stride) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + + switch (oc8mod) { + case 1: + _mm_store_ss(dst_c1, src1); + break; + case 2: + _mm_storel_pi((__m64 *)(dst_c1), src1); + break; + case 3: + _mm_storel_pi((__m64 *)(dst_c1), src1); + src1 = _mm_unpackhi_ps(src1, src1); + _mm_store_ss(dst_c1 + 2, src1); + break; + case 4: + _mm_storeu_ps(dst_c1, src1); + break; + case 5: + _mm_storeu_ps(dst_c1, src1); + _mm_store_ss(dst_c1 + 4, src2); + break; + case 6: + _mm_storeu_ps(dst_c1, src1); + _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); + break; + case 7: + _mm_storeu_ps(dst_c1, src1); + _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); + src2 = _mm_unpackhi_ps(src2, src2); + _mm_store_ss(dst_c1 + 6, src2); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0460476fbd685b9dadce7a5a2985cf8726e4a641 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +static inline void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, + const __m128 weight, const float v1, const float v2, const float v3, + const float v4) { + *dst1 = _mm_add_ps(*dst1, _mm_mul_ps(weight, _mm_set_ps1(v1))); + *dst2 = _mm_add_ps(*dst2, _mm_mul_ps(weight, _mm_set_ps1(v2))); + *dst3 = _mm_add_ps(*dst3, _mm_mul_ps(weight, _mm_set_ps1(v3))); + *dst4 = _mm_add_ps(*dst4, _mm_mul_ps(weight, _mm_set_ps1(v4))); +} + +static inline void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4, + const float *src) { + *src1 = _mm_loadu_ps(src); + *src2 = _mm_loadu_ps(src + 4); + *src3 = _mm_loadu_ps(src + 8); + *src4 = _mm_loadu_ps(src + 12); +} + +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + const float *src_tmp = src; + for (int i = 0; i < oc4; ++i) { + float *dst_tmp = dst; + src = src_tmp; + size_t ic4_tmp = ic4 - 1; + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += 16; + __m128 weight_data[4]; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight_data[2] = _mm_loadu_ps(weight + 8); + weight_data[3] = _mm_loadu_ps(weight + 12); + weight += 16; + __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))); + __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))); + __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))); + __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))); + __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))); + __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))); + __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), 
MS_F32X4_GETI(src4, j)); + } + if (ic4_tmp != 0) { + ic4_tmp -= 1; + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)))); + for (; ic4_tmp != 0; ic4_tmp -= 1) { + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src1, 3)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src2, 3)))); + src1 = _mm_loadu_ps(src); + src2 = _mm_loadu_ps(src + 4); + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src3, 3)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src4, 3)))); + src3 = _mm_loadu_ps(src + 8); + src4 = _mm_loadu_ps(src + 12); + src += 16; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], MS_F32X4_GETI(src1, 0), + MS_F32X4_GETI(src2, 0), MS_F32X4_GETI(src3, 0), MS_F32X4_GETI(src4, 0)); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], MS_F32X4_GETI(src1, 3), + MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3)); + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)))); + } + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], MS_F32X4_GETI(src1, 3), + MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3)); + + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + 
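+      // last input tile of the final ic4 block just loaded; it feeds the dst5..dst8 accumulators below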
src += 16; + for (int j = 0; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + } + _mm_storeu_ps(dst, dst1); + _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3); + _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5); + _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7); + _mm_storeu_ps(dst + 28, dst8); + dst = dst_tmp + cal_num; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c new file mode 100644 index 0000000000000000000000000000000000000000..7c382e93f3c9b95810dc3ee13829ede0e4ef1255 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c @@ -0,0 +1,349 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" + +void WinogradPostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + size_t stride = oc4div + oc4mod; + plane_stride /= sizeof(float); + int loop_c4 = 0; + size_t src_stride = plane_size * C4NUM + plane_stride; + for (; loop_c4 <= (int)(oc4div)-C16NUM; loop_c4 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + __m128 bias3 = _mm_setzero_ps(); + __m128 bias4 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias3 = _mm_loadu_ps(bias + C8NUM); + bias4 = _mm_loadu_ps(bias + C12NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM); + __m128 src13 = _mm_loadu_ps(src + src_stride * C3NUM); + __m128 src14 = _mm_loadu_ps(src + src_stride * C3NUM + C4NUM); + + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src9 = _mm_add_ps(src9, bias3); + src10 = _mm_add_ps(src10, bias3); + src13 = _mm_add_ps(src13, bias4); + src14 = _mm_add_ps(src14, bias4); + + ActBlock8(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + _mm_storeu_ps(dst_c4 + C8NUM, src9); + _mm_storeu_ps(dst_c4 + C12NUM, 
src13); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + _mm_storeu_ps(dst_c4 + C8NUM, src10); + _mm_storeu_ps(dst_c4 + C12NUM, src14); + dst_c4 += stride; + + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + __m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM); + __m128 src15 = _mm_loadu_ps(src + src_stride * C3NUM + C8NUM); + __m128 src16 = _mm_loadu_ps(src + src_stride * C3NUM + C12NUM); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + src11 = _mm_add_ps(src11, bias3); + src12 = _mm_add_ps(src12, bias3); + src15 = _mm_add_ps(src15, bias4); + src16 = _mm_add_ps(src16, bias4); + + ActBlock8(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + _mm_storeu_ps(dst_c4 + C8NUM, src11); + _mm_storeu_ps(dst_c4 + C12NUM, src15); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + _mm_storeu_ps(dst_c4 + C8NUM, src12); + _mm_storeu_ps(dst_c4 + C12NUM, src16); + dst_c4 += stride; + src += C16NUM; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + __m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src4 = _mm_loadu_ps(src + src_stride * C3NUM); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias3); + src4 = _mm_add_ps(src4, bias4); + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + _mm_storeu_ps(dst_c4 + C8NUM, src3); + _mm_storeu_ps(dst_c4 + C12NUM, src4); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += C3NUM * src_stride; + } + for (; loop_c4 <= (int)(oc4div)-C12NUM; loop_c4 += C12NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + __m128 bias3 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias3 = _mm_loadu_ps(bias + C8NUM); + bias += C12NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + __m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM); + __m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + src9 = _mm_add_ps(src9, bias3); + src10 = _mm_add_ps(src10, bias3); + src11 
= _mm_add_ps(src11, bias3); + src12 = _mm_add_ps(src12, bias3); + + ActBlock12(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, relu_type == 1, + relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + _mm_storeu_ps(dst_c4 + C8NUM, src9); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + _mm_storeu_ps(dst_c4 + C8NUM, src10); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + _mm_storeu_ps(dst_c4 + C8NUM, src11); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + _mm_storeu_ps(dst_c4 + C8NUM, src12); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + __m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias3); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src3, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + _mm_storeu_ps(dst_c4 + C8NUM, src3); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += C2NUM * src_stride; + } + + for (; loop_c4 <= (int)(oc4div)-C8NUM; loop_c4 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src2, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += src_stride; + } + for (; loop_c4 < (int)(oc4div); loop_c4 += C4NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + if (bias != NULL) { 
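+      // load the per-channel bias for the two C4 groups; when bias is NULL the zero-initialized accumulators are used instead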
+ bias1 = _mm_loadu_ps(bias); + bias += C4NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + src1 = _mm_add_ps(src1, bias1); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + } + if (oc4mod == 0) { + return; + } + __m128 bias1 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias += C4NUM; + } + float *dst_c1 = dst + oc4div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + src += C4NUM; + src1 = _mm_add_ps(src1, bias1); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + + switch (oc4mod) { + case 1: + _mm_store_ss(dst_c1, src1); + dst_c1 += stride; + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), src1); + dst_c1 += stride; + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), src1); + src1 = _mm_unpackhi_ps(src1, src1); + _mm_store_ss(dst_c1 + C2NUM, src1); + dst_c1 += stride; + break; + case C4NUM: + _mm_storeu_ps(dst_c1, src1); + dst_c1 += stride; + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c new file mode 100644 index 0000000000000000000000000000000000000000..168a5273f6f404264125b3e97ced86ac0d65b7b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c @@ -0,0 +1,376 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c4 = length * 4; + size_t S_step = length * w * 4; + for (int h1 = 0; h1 < h; ++h1) { + const float *SW = S; + memset(M, 0, len_c4 * w * sizeof(float)); + for (int w_tmp = w; w_tmp > 0; --w_tmp) { + const float *SK = SW; + const float *BK = B; + int k_tmp = k; + for (; k_tmp >= 7; k_tmp -= 7) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + __m128 k5 = _mm_load_ps1(BK + 4 * h); + __m128 k6 = _mm_load_ps1(BK + 5 * h); + __m128 k7 = _mm_load_ps1(BK + 6 * h); + BK += 7 * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + M2 = _mm_fmadd_ps(s4, k4, M2); + __m128 s5 = _mm_loadu_ps(SK + 4 * S_step); + M1 = _mm_fmadd_ps(s5, k5, M1); + __m128 s6 = _mm_loadu_ps(SK + 5 * S_step); + M2 = _mm_fmadd_ps(s6, k6, M2); + __m128 s7 = _mm_loadu_ps(SK + 6 * S_step); + M1 = _mm_fmadd_ps(s7, k7, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + __m128 s5 = _mm_loadu_ps(SK + 4 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s5, k5)); + __m128 s6 = _mm_loadu_ps(SK + 5 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s6, k6)); + __m128 s7 = _mm_loadu_ps(SK + 6 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 7 * S_step - len_c4; + } + for (; k_tmp >= 4; k_tmp -= 4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + BK += 4 * h; + int len_tmp = length; +#ifdef ENABLE_AVX + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C8NUM, M += C8NUM) { + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_loadu_ps(M + C4NUM); + __m128 s1 = _mm_loadu_ps(SK); + __m128 s11 = _mm_loadu_ps(SK + C4NUM); + M1 = _mm_fmadd_ps(s1, k1, M1); + M2 = _mm_fmadd_ps(s11, k1, M2); + __m128 s2 = _mm_loadu_ps(SK + S_step); + __m128 s22 = _mm_loadu_ps(SK + S_step + C4NUM); + M1 = _mm_fmadd_ps(s2, k2, M1); + M2 = _mm_fmadd_ps(s22, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + __m128 s33 = _mm_loadu_ps(SK + 2 * S_step + C4NUM); + M1 = _mm_fmadd_ps(s3, k3, M1); + M2 = _mm_fmadd_ps(s33, k3, M2); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + __m128 s44 = _mm_loadu_ps(SK + 3 * S_step + C4NUM); + M1 = _mm_fmadd_ps(s4, k4, M1); + M2 = _mm_fmadd_ps(s44, k4, M2); + _mm_storeu_ps(M, M1); + _mm_storeu_ps(M + C4NUM, M2); + } +#endif + for (; len_tmp > 0; --len_tmp, SK += 4, M += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 
= _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + M2 = _mm_fmadd_ps(s4, k4, M2); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 4 * S_step - len_c4; + } + for (; k_tmp >= 3; k_tmp -= 3) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + BK += 3 * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 3 * S_step - len_c4; + } + for (; k_tmp > 0; k_tmp -= 1) { + __m128 k1 = _mm_load_ps1(BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); +#ifdef ENABLE_AVX + M1 = _mm_fmadd_ps(s0, k1, M1); +#else + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); +#endif + _mm_storeu_ps(M, M1); + } + M -= len_c4; + SK += S_step - len_c4; + } + SW += len_c4; + M += len_c4; + } + B += 1; + } +} + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c4 = length * 4, k_step = len_c4 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { + const float *BW = B; + memset(M, 0, len_c4 * w * sizeof(float)); + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c4) { + const float *SK = S, *BK = BW; + int k_tmp = k; + for (; k_tmp >= 7; k_tmp -= 7, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + __m128 k5 = _mm_load_ps1(BK + 4 * h); + __m128 k6 = _mm_load_ps1(BK + 5 * h); + __m128 k7 = _mm_load_ps1(BK + 6 * h); + BK += 7 * h; + const float *S2 = SK + len_c4, *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4, *S5 = S4 + len_c4; + const float *S6 = S5 + len_c4, *S7 = S6 + len_c4; + for (int len_tmp = length; len_tmp > 0; + --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4, S5 += 4, S6 += 4, S7 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(S4); + M2 = _mm_fmadd_ps(s4, k4, 
M2); + __m128 s5 = _mm_loadu_ps(S5); + M1 = _mm_fmadd_ps(s5, k5, M1); + __m128 s6 = _mm_loadu_ps(S6); + M2 = _mm_fmadd_ps(s6, k6, M2); + __m128 s7 = _mm_loadu_ps(S7); + M1 = _mm_fmadd_ps(s7, k7, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(S4); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + __m128 s5 = _mm_loadu_ps(S5); + M1 = _mm_add_ps(M1, _mm_mul_ps(s5, k5)); + __m128 s6 = _mm_loadu_ps(S6); + s1 = _mm_add_ps(s1, _mm_mul_ps(s6, k6)); + __m128 s7 = _mm_loadu_ps(S7); + M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S7; + } + for (; k_tmp >= 4; k_tmp -= 4, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + BK += 4 * h; + const float *S2 = SK + len_c4; + const float *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4; + int len_tmp = length; +#ifdef ENABLE_AVX + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) { + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_loadu_ps(M + C4NUM); + __m128 s1 = _mm_loadu_ps(SK); + __m128 s11 = _mm_loadu_ps(SK + C4NUM); + M1 = _mm_fmadd_ps(s1, k1, M1); + M2 = _mm_fmadd_ps(s11, k1, M2); + __m128 s2 = _mm_loadu_ps(S2); + __m128 s22 = _mm_loadu_ps(S2 + C4NUM); + M1 = _mm_fmadd_ps(s2, k2, M1); + M2 = _mm_fmadd_ps(s22, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + __m128 s33 = _mm_loadu_ps(S3 + C4NUM); + M1 = _mm_fmadd_ps(s3, k3, M1); + M2 = _mm_fmadd_ps(s33, k3, M2); + __m128 s4 = _mm_loadu_ps(S4); + __m128 s44 = _mm_loadu_ps(S4 + C4NUM); + M1 = _mm_fmadd_ps(s4, k4, M1); + M2 = _mm_fmadd_ps(s44, k4, M2); + _mm_storeu_ps(M, M1); + _mm_storeu_ps(M + C4NUM, M2); + } +#endif + for (; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(S4); + M2 = _mm_fmadd_ps(s4, k4, M2); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(S4); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S4; + } + for (; k_tmp >= 3; k_tmp -= 3, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + BK += 3 * h; + const float *S2 = SK + len_c4; + const float *S3 = S2 + len_c4; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s0, k1, M1); + __m128 s1 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s1, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + M1 = _mm_add_ps(M1, M2); + 
_mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S3; + } + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); +#ifdef ENABLE_AVX + M1 = _mm_fmadd_ps(s0, k1, M1); +#else + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); +#endif + _mm_storeu_ps(M, M1); + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h new file mode 100644 index 0000000000000000000000000000000000000000..5885954a77e715eec6396f009c4b87e0767fc933 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h @@ -0,0 +1,390 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ +#define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ + +#define SSE_ROW_NUM_1 1 +#define SSE_ROW_NUM_2 2 +#define SSE_ROW_NUM_3 3 + +#define SSE_INDEX_1 1 +#define SSE_INDEX_2 2 +#define SSE_INDEX_3 3 +#define SSE_INDEX_4 4 +#define SSE_INDEX_5 5 +#define SSE_INDEX_6 6 + +#define SSE_SHUFFLE_0321 (_MM_SHUFFLE(0, 3, 2, 1)) + +#define SSE_ACT_RELU 1 +#define SSE_ACT_RELU6 3 + +static inline void ActBlock1(__m128 *v1, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + } +} + +static inline void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + } +} + +static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + } +} + +static inline void ActBlock12(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, + __m128 *v8, __m128 *v9, __m128 *v10, __m128 *v11, __m128 *v12, size_t relu, + 
size_t relu6) { + if (relu || relu6) { + __m128 zero_ma = _mm_setzero_ps(); + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + *v5 = _mm_max_ps(zero_ma, *v5); + *v6 = _mm_max_ps(zero_ma, *v6); + *v7 = _mm_max_ps(zero_ma, *v7); + *v8 = _mm_max_ps(zero_ma, *v8); + *v9 = _mm_max_ps(zero_ma, *v9); + *v10 = _mm_max_ps(zero_ma, *v10); + *v11 = _mm_max_ps(zero_ma, *v11); + *v12 = _mm_max_ps(zero_ma, *v12); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + *v5 = _mm_min_ps(relu6_ma, *v5); + *v6 = _mm_min_ps(relu6_ma, *v6); + *v7 = _mm_min_ps(relu6_ma, *v7); + *v8 = _mm_min_ps(relu6_ma, *v8); + *v9 = _mm_min_ps(relu6_ma, *v9); + *v10 = _mm_min_ps(relu6_ma, *v10); + *v11 = _mm_min_ps(relu6_ma, *v11); + *v12 = _mm_min_ps(relu6_ma, *v12); + } +} + +static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, + __m128 *v8, size_t relu_type) { + __m128 relu6 = _mm_set_ps1(6.0); + __m128 zero = _mm_setzero_ps(); + switch (relu_type) { + case SSE_ACT_RELU6: + *v1 = _mm_min_ps(*v1, relu6); + *v2 = _mm_min_ps(*v2, relu6); + *v3 = _mm_min_ps(*v3, relu6); + *v4 = _mm_min_ps(*v4, relu6); + *v5 = _mm_min_ps(*v5, relu6); + *v6 = _mm_min_ps(*v6, relu6); + *v7 = _mm_min_ps(*v7, relu6); + *v8 = _mm_min_ps(*v8, relu6); + case SSE_ACT_RELU: + *v1 = _mm_max_ps(*v1, zero); + *v2 = _mm_max_ps(*v2, zero); + *v3 = _mm_max_ps(*v3, zero); + *v4 = _mm_max_ps(*v4, zero); + *v5 = _mm_max_ps(*v5, zero); + *v6 = _mm_max_ps(*v6, zero); + *v7 = _mm_max_ps(*v7, zero); + *v8 = _mm_max_ps(*v8, zero); + default: + break; + } +} + +static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_store_ss(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst7); + } +} + +static inline void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, 
SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst7); + *dst += stride; + *dst += SSE_INDEX_2; + } +} + +static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, 
*dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, + __m128 *dst5, __m128 *dst6, __m128 *dst7, __m128 *dst8) { + __m128 bias1 = _mm_loadu_ps(bias_ptr); + __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM); + *dst1 = _mm_add_ps(*dst1, bias1); + *dst2 = _mm_add_ps(*dst2, bias2); + *dst3 = _mm_add_ps(*dst3, bias1); + *dst4 = _mm_add_ps(*dst4, bias2); + *dst5 = _mm_add_ps(*dst5, bias1); + *dst6 = _mm_add_ps(*dst6, bias2); + *dst7 = _mm_add_ps(*dst7, bias1); + *dst8 = _mm_add_ps(*dst8, bias2); +} + +#endif // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c new file mode 100644 index 0000000000000000000000000000000000000000..19a32c48ed61ce5d32075d7ef0c9f9e454ea745d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/init_exec_env.h" + +static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16]; + +void RegKernelCreator(int opType, int dataType, KernelCreator creator) { + g_kernelCreatorRegistry[opType][REGIST_DT(dataType)] = creator; +} + +void Init_MSC_VER_kernels(void) { +#ifdef _MSC_VER + /* The MSVC build does not run the constructor-attribute auto-registration, + * so register all kernels here on the first call. */ + static bool inited = false; + if (inited == false) { + init_vs_kernels(g_kernelCreatorRegistry); + inited = true; + } +#endif + return; +} + +bool checkOpValid(int opType) { + if (opType < PrimType_MIN || opType >= PrimType_MAX) { + return false; + } + return true; +} + +bool SupportKernelC(int opType, int dataType) { + Init_MSC_VER_kernels(); + const int length = 16; + if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) { + return false; + } + if (!checkOpValid(opType)) { + return false; + } + KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)]; + return creator != NULL; +} + +int DefaultThreadUpdate(int32_t type, int64_t load, int64_t store, int64_t unit, int thread) { + return thread > 0 ? thread : 1; +} + +int NNACLKernelInferShape(struct KernelBase *self) { return NNACL_ERR; } + +int NNACLCheckKernelBase(KernelBase *kernel_base) { + CheckExecEnv(kernel_base); + + if (kernel_base->param_ == NULL) { + return NNACL_ERR; + } + + if (kernel_base->thread_nr_ <= 0 || kernel_base->thread_nr_ > MAX_THREAD_NUM) { + return NNACL_ERR; + } + + if (kernel_base->in_size_ == 0 || kernel_base->in_ == NULL) { + return NNACL_ERR; + } + if (kernel_base->out_size_ == 0 || kernel_base->out_ == NULL) { + return NNACL_ERR; + } + return NNACL_OK; +} + +KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, + int data_type, ExecEnv *env) { + Init_MSC_VER_kernels(); + if (param == NULL) { + return NULL; + } + if (!checkOpValid(param->type_)) { + return NULL; + } + + KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)]; + if (creator == NULL) { + return NULL; + } + + KernelBase *kernel_base = creator(param, data_type); + if (kernel_base == NULL) { + return NULL; + } + + kernel_base->InferShape = NNACLKernelInferShape; + kernel_base->UpdateThread = DefaultThreadUpdate; + kernel_base->env_ = env; + kernel_base->param_ = param; + kernel_base->thread_nr_ = param->thread_num_; + kernel_base->train_session_ = param->is_train_session_; + kernel_base->in_ = ins; + kernel_base->in_size_ = in_size; + kernel_base->out_ = outs; + kernel_base->out_size_ = out_size; + + int ret = NNACLCheckKernelBase(kernel_base); + if (ret != NNACL_OK) { + return NULL; + } + + return kernel_base; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..84378c7fe54716e3fe688eaff45b991730a3a37c --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_H_ +#define NNACL_KERNEL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/infer/common_infer.h" + +typedef struct ExecEnv { + void *allocator_; + void *thread_pool_; + void *(*Alloc)(void *allocator, size_t sz); + void (*Free)(void *allocator, void *ptr); + int (*ParallelLaunch)(void *thread_pool, void *task, void *param, int task_num); +} ExecEnv; + +typedef struct KernelBase { + int (*Release)(struct KernelBase *self); + int (*Prepare)(struct KernelBase *self); + int (*Compute)(struct KernelBase *self); + int (*Resize)(struct KernelBase *self); + int (*InferShape)(struct KernelBase *self); + int (*UpdateThread)(int32_t type, int64_t load, int64_t store, int64_t unit, int thread); + OpParameter *param_; + int thread_nr_; + ExecEnv *env_; + TensorC **in_; + size_t in_size_; + TensorC **out_; + size_t out_size_; + bool train_session_; + void *workspace_; /* only used in train */ + int work_size_; /* only used in train */ +} KernelBase; + +#ifdef _MSC_VER +#define REG_KERNEL_CREATOR(op, data_type, func) +#else +#define REG_KERNEL_CREATOR(op, data_type, func) \ + __attribute__((constructor(102))) void Reg##op##data_type##Creator() { RegKernelCreator(op, data_type, func); } +#endif + +#define REGIST_DT(DataType) (DataType - kNumberTypeBegin - 1) +typedef KernelBase *(*KernelCreator)(OpParameter *param, int data_type); +void RegKernelCreator(int opType, int dataType, KernelCreator func); + +#ifdef __cplusplus +extern "C" { +#endif +KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, + int data_type, ExecEnv *env); +bool SupportKernelC(int opType, int dataType); +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c new file mode 100644 index 0000000000000000000000000000000000000000..50cb13897dc2987f6af7d8a8c836b76edbfc2469 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c @@ -0,0 +1,194 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
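For orientation, a minimal sketch of how a host could drive the kernel.h interface above, assuming a serial ExecEnv. The helper names (EnvAlloc, EnvFree, EnvLaunch, RunKernelOnce) are illustrative only; the task-callback signature matches the launch callbacks used throughout this patch (e.g. ActivationImpl below).

#include <stdlib.h>

static void *EnvAlloc(void *allocator, size_t sz) { (void)allocator; return malloc(sz); }
static void EnvFree(void *allocator, void *ptr) { (void)allocator; free(ptr); }
static int EnvLaunch(void *thread_pool, void *task, void *param, int task_num) {
  (void)thread_pool; /* serial fallback: run every task slice in order */
  int (*fn)(void *, int, float, float) = (int (*)(void *, int, float, float))task;
  for (int i = 0; i < task_num; ++i) {
    int ret = fn(param, i, 0.0f, 0.0f);
    if (ret != NNACL_OK) return ret;
  }
  return NNACL_OK;
}

static int RunKernelOnce(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size) {
  ExecEnv env = {NULL, NULL, EnvAlloc, EnvFree, EnvLaunch};
  KernelBase *kernel = CreateKernel(param, ins, in_size, outs, out_size, kNumberTypeFloat32, &env);
  if (kernel == NULL) return NNACL_ERR;
  int ret = kernel->Prepare(kernel);
  if (ret == NNACL_OK) ret = kernel->Resize(kernel);
  if (ret == NNACL_OK) ret = kernel->Compute(kernel);
  kernel->Release(kernel);
  free(kernel); /* creators allocate the struct with malloc */
  return ret;
}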
+ */ + +#include "nnacl_c/kernel/activation.h" +#include "nnacl_c/activation_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/activation_fp16.h" +#endif + +typedef struct ActivationStruct { + KernelBase base; + int data_type_; + ActType act_type_; +} ActivationStruct; + +int ActivationResize(struct KernelBase *self) { + ActivationStruct *activation = (ActivationStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(activation); + self->thread_nr_ = self->UpdateThread(TC_TYPE(PrimType_Activation, activation->act_type_), 1, 1, + NNACLGetElementNum(self->out_[0]), self->thread_nr_); + return NNACL_OK; +} + +int activation_fp32_run(ActivationStruct *activation, int task_id, int count, int stride) { + ActivationParameter *param = (ActivationParameter *)activation->base.param_; + float *input = activation->base.in_[0]->data_; + float *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return Fp32Relu(input + task_id * stride, count, output + task_id * stride); + case ActType_Relu6: + return Fp32Relu6(input + task_id * stride, count, output + task_id * stride); + case ActType_LeakyRelu: + return LRelu(input + task_id * stride, count, output + task_id * stride, param->alpha_); + case ActType_Sigmoid: + return Sigmoid(input + task_id * stride, count, output + task_id * stride); + case ActType_Tanh: + return Tanh(input + task_id * stride, count, output + task_id * stride); + case ActType_Swish: + return Swish(input + task_id * stride, count, output + task_id * stride); + case ActType_HSwish: + return HSwish(input + task_id * stride, count, output + task_id * stride); + case ActType_HSigmoid: + return HSigmoid(input + task_id * stride, count, output + task_id * stride); + case ActType_HardTanh: + return HardTanh(input + task_id * stride, count, output + task_id * stride, param->min_val_, param->max_val_); + case ActType_Gelu: + return Gelu(input + task_id * stride, count, output + task_id * stride, param->approximate_); + case ActType_Softplus: + return Softplus(input + task_id * stride, count, output + task_id * stride); + case ActType_Elu: + return Elu(input + task_id * stride, count, output + task_id * stride, param->alpha_); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int activation_int32_run(ActivationStruct *activation, int task_id, int count, int stride) { + int32_t *input = activation->base.in_[0]->data_; + int32_t *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return Int32Relu(input + task_id * stride, count, output + task_id * stride); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int activation_fp16_run(ActivationStruct *activation, int task_id, int count, int stride) { +#ifdef ENABLE_FP16 + ActivationParameter *param = (ActivationParameter *)activation->base.param_; + float16_t *input = activation->base.in_[0]->data_; + float16_t *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return ReluFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Relu6: + return Relu6Fp16(input + stride * task_id, 
output + stride * task_id, count); + case ActType_LeakyRelu: + return LReluFp16(input + stride * task_id, output + stride * task_id, count, param->alpha_); + case ActType_Sigmoid: + return SigmoidFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Tanh: + return TanhFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HSwish: + return HSwishFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Swish: + return SwishFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HSigmoid: + return HSigmoidFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HardTanh: + return HardTanhFp16(input + stride * task_id, count, output + stride * task_id, param->min_val_, param->max_val_); + case ActType_Gelu: + return GeluFp16(input + stride * task_id, count, output + stride * task_id, true); + case ActType_Softplus: + return SoftplusFp16(input + stride * task_id, count, output + stride * task_id); + case ActType_Elu: + return EluFp16(input + stride * task_id, count, output + stride * task_id, param->alpha_); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +#endif + return NNACL_DISABLE_FP16; +} + +int ActivationImpl(void *cdata, int task_id, float l, float r) { + ActivationStruct *activation = (ActivationStruct *)cdata; + + int ele_num = NNACLGetElementNum(activation->base.in_[0]); + NNACL_CHECK_ZERO_RETURN_ERR(activation->base.thread_nr_); + int stride = UP_DIV(ele_num, activation->base.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, NNACL_ERR); + int count = MSMIN(stride, ele_num - stride * task_id); + if (count <= 0) { + return NNACL_OK; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(stride, task_id), NNACL_ERR); + + switch (activation->data_type_) { + case kNumberTypeFloat32: + return activation_fp32_run(activation, task_id, count, stride); + case kNumberTypeFloat16: + return activation_fp16_run(activation, task_id, count, stride); + case kNumberTypeInt32: + return activation_int32_run(activation, task_id, count, stride); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int ActivationCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ActivationImpl, self, self->thread_nr_); +} + +KernelBase *CreateActivation(OpParameter *param, int data_type) { + ActivationParameter *act = (ActivationParameter *)(param); + + int type = act->type_; + if (data_type == kNumberTypeInt32) { + if (type != ActType_Relu) { + return NULL; + } + } + + if (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16) { + if (type != ActType_Relu && type != ActType_Relu6 && type != ActType_LeakyRelu && type != ActType_Sigmoid && + type != ActType_Tanh && type != ActType_HSwish && type != ActType_Swish && type != ActType_HardTanh && + type != ActType_Gelu && type != ActType_HSigmoid && type != ActType_Softplus && type != ActType_Elu) { + return NULL; + } + } + + ActivationStruct *activation = (ActivationStruct *)malloc(sizeof(ActivationStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(activation); + memset(activation, 0, sizeof(ActivationStruct)); + + activation->data_type_ = data_type; + activation->act_type_ = act->type_; + activation->base.Prepare = DefaultPrepare1In1Out; + activation->base.Release = DefaultRelease; + activation->base.Resize = ActivationResize; + activation->base.Compute = ActivationCompute; + return (KernelBase *)activation; +} + +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat32, 
CreateActivation) +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat16, CreateActivation) +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeInt32, CreateActivation) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h new file mode 100644 index 0000000000000000000000000000000000000000..b7b06c0138f31e0a11b2d7fd834d22240bb5fcb0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_ACTIVATION_H_ +#define NNACL_KERNEL_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateActivation(OpParameter *param, int data_type); + +#endif  // NNACL_KERNEL_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c new file mode 100644 index 0000000000000000000000000000000000000000..3916981a6064b3cf599b19baae7cbe140fd8cf65 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c @@ -0,0 +1,144 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
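As context for the registration lines above: on non-MSVC builds, each REG_KERNEL_CREATOR line expands, per the macro in kernel.h, to a constructor-priority function that fills the creator registry before main() runs. A sketch of the expansion for the float32 activation entry:

/* Expansion of REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat32, CreateActivation)
 * under the non-_MSC_VER branch of the macro defined in kernel.h. */
__attribute__((constructor(102))) void RegPrimType_ActivationkNumberTypeFloat32Creator() {
  RegKernelCreator(PrimType_Activation, kNumberTypeFloat32, CreateActivation);
}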
+ */ + +#include "nnacl_c/kernel/addn.h" +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arithmetic_fp16.h" +#endif + +int AddNLaunch(void *cdata, int task_id, float l, float r) { + AddNStruct *addn = (AddNStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + int count_per_thread = UP_DIV(addn->elements_num_, addn->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, count_per_thread, NNACL_ERR); + int count = MSMIN(count_per_thread, addn->elements_num_ - task_id * count_per_thread); + int stride = count_per_thread * task_id; + +#ifdef ENABLE_FP16 + if (addn->data_type_ == kNumberTypeFloat16) { + return ElementAddFp16((float16_t *)addn->in1_addr_ + stride, (float16_t *)addn->in2_addr_ + stride, + (float16_t *)addn->out_addr_ + stride, count); + } +#endif + return ElementAdd((float *)addn->in1_addr_ + stride, (float *)addn->in2_addr_ + stride, + (float *)addn->out_addr_ + stride, count); +} + +void AddNCompute(AddNStruct *addn, bool same_shape, bool first_scalar) { +#ifdef ENABLE_FP16 + if (addn->data_type_ == kNumberTypeFloat16) { + if (same_shape) { + ElementAddFp16((float16_t *)addn->in1_addr_, (float16_t *)addn->in2_addr_, (float16_t *)addn->out_addr_, + addn->elements_num_); + } else { + ElementOptAddFp16((float16_t *)addn->in1_addr_, (float16_t *)addn->in2_addr_, (float16_t *)addn->out_addr_, + addn->elements_num_, first_scalar); + } + return; + } +#endif + + if (same_shape) { + ElementAdd((float *)addn->in1_addr_, (float *)addn->in2_addr_, (float *)addn->out_addr_, addn->elements_num_); + } else { + ElementOptAdd((float *)addn->in1_addr_, (float *)addn->in2_addr_, (float *)addn->out_addr_, addn->elements_num_, + first_scalar); + } + return; +} + +int AddNComputeNoParallel(AddNStruct *addn) { + TensorC *in0_tensor = addn->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in0_tensor); + TensorC *in1_tensor = addn->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in1_tensor); + AddNCompute(addn, NNACLIsShapeSame(in0_tensor, in1_tensor), NNACLGetElementNum(in0_tensor) == 1); + + for (size_t i = Index2; i < addn->base_.in_size_; i++) { + TensorC *in_tensor = addn->base_.in_[i]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + addn->in1_addr_ = in_tensor->data_; + addn->in2_addr_ = addn->out_addr_; + AddNCompute(addn, NNACLIsShapeSame(in_tensor, addn->base_.out_[OUTPUT_INDEX]), NNACLGetElementNum(in_tensor) == 1); + } + return NNACL_OK; +} + +int AddnResize(struct KernelBase *self) { + AddNStruct *addn = (AddNStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + addn->elements_num_ = NNACLGetElementNum(out_tensor); + return NNACL_OK; +} + +int AddnCompute(struct KernelBase *self) { + AddNStruct *addn = (AddNStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + addn->in1_addr_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in1_addr_); + addn->in2_addr_ = self->in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in2_addr_); + addn->out_addr_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->out_addr_); + + if (addn->elements_num_ < self->thread_nr_) { + return AddNComputeNoParallel(addn); + } + + for (int i = 0; i < self->in_size_; i++) { + TensorC *in_tensor = self->in_[i]; + if (!NNACLIsShapeSame(in_tensor, self->out_[OUTPUT_INDEX])) { + return NNACL_ADDN_SHAPE_UNMATCH; + } + } + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, 
AddNLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + + for (size_t i = Index2; i < self->in_size_; ++i) { + addn->in1_addr_ = self->in_[i]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in1_addr_); + addn->in2_addr_ = addn->out_addr_; + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, AddNLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +KernelBase *CreateAddN(OpParameter *param, int data_type) { + AddNStruct *addn = (AddNStruct *)malloc(sizeof(AddNStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(addn); + addn->data_type_ = data_type; + addn->base_.Prepare = DefaultPrepare1In1Out; + addn->base_.Resize = AddnResize; + addn->base_.Release = DefaultRelease; + addn->base_.Compute = AddnCompute; + return (KernelBase *)addn; +} + +REG_KERNEL_CREATOR(PrimType_AddN, kNumberTypeFloat16, CreateAddN) +REG_KERNEL_CREATOR(PrimType_AddN, kNumberTypeFloat32, CreateAddN) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h new file mode 100644 index 0000000000000000000000000000000000000000..90430088583b7f421d1461d55a0034916e7b597f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ADDN_H_ +#define NNACL_KERNEL_ADDN_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct AddNStruct { + KernelBase base_; + int data_type_; + int elements_num_; + void *in1_addr_; + void *in2_addr_; + void *out_addr_; +} AddNStruct; + +KernelBase *CreateAddN(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ADDN_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c new file mode 100644 index 0000000000000000000000000000000000000000..2bd0dd99a51b456d938c58ed38ca110587ec4849 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c @@ -0,0 +1,127 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
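Both AddNLaunch above and ActivationImpl earlier split the flat element range across tasks with the same UP_DIV/MSMIN arithmetic. A self-contained sketch of that partitioning; the function name is illustrative, not part of the patch:

/* total = 10, thread_nr = 4 -> per_thread = 3, tasks get counts 3, 3, 3, 1. */
static void SliceForTask(int total, int thread_nr, int task_id, int *start, int *count) {
  int per_thread = (total + thread_nr - 1) / thread_nr; /* UP_DIV(total, thread_nr) */
  *start = task_id * per_thread;
  int remain = total - *start;
  *count = remain < per_thread ? remain : per_thread;   /* MSMIN(per_thread, remain) */
  if (*count < 0) *count = 0;                           /* trailing tasks may get nothing */
}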
+ */ + +#include "nnacl_c/kernel/arg_min_max.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/arg_min_max_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arg_min_max_fp16.h" +#endif + +int ArgMinMaxPrepare(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + arg_min_max->arg_elements_alloc_ = param->topk_ > Num1 || param->keep_dims_; + arg_min_max->compute_.topk_ = param->topk_; + arg_min_max->compute_.axis_ = param->axis_; + arg_min_max->compute_.keep_dims_ = param->keep_dims_; + arg_min_max->compute_.out_value_ = param->out_value_; + arg_min_max->compute_.get_max_ = self->param_->type_ == PrimType_ArgMinFusion ? false : true; + return NNACL_OK; +} + +int ArgMinMaxResize(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxComputeParam *compute = &arg_min_max->compute_; + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + ComputeStrides(input_tensor->shape_, compute->in_strides_, input_tensor->shape_size_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + ComputeStrides(output_tensor->shape_, compute->out_strides_, output_tensor->shape_size_); + + compute->dims_size_ = (int)input_tensor->shape_size_; + compute->axis_ = compute->axis_ < 0 ? compute->axis_ + compute->dims_size_ : compute->axis_; + NNACL_CHECK_FALSE(compute->topk_ <= 0, NNACL_ARG_MIN_MAX_AXIS_INVALID); + NNACL_CHECK_FALSE(compute->topk_ > input_tensor->shape_[compute->axis_], NNACL_ARG_MIN_MAX_AXIS_INVALID); + return NNACL_OK; +} + +int ArgMinMaxCompute(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + void *in_data = in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + void *out_data = out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + + void *out_value = NULL; + if (self->out_size_ == TWO_TENSOR) { + out_value = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_value); + } + + if (arg_min_max->arg_elements_alloc_) { + int arg_size = in_tensor->shape_[arg_min_max->compute_.axis_] * (int)sizeof(ArgElement); + NNACL_CHECK_MALLOC_SIZE(arg_size); + arg_min_max->compute_.arg_elements_ = (ArgElement *)self->env_->Alloc(self->env_->allocator_, arg_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arg_min_max->compute_.arg_elements_); + } + + int ret = NNACL_OK; + int *in_shape = in_tensor->shape_; + if (in_tensor->data_type_ == kNumberTypeFloat32) { + ArgMinMaxFp32((float *)in_data, out_data, (float *)out_value, in_shape, &arg_min_max->compute_); +#ifdef ENABLE_FP16 + } else if (in_tensor->data_type_ == kNumberTypeFloat16) { + ArgMinMaxFp16((float16_t *)in_data, out_data, (float16_t *)out_value, in_shape, &arg_min_max->compute_); +#endif + } else if (in_tensor->data_type_ == kNumberTypeInt32) { + ArgMinMaxInt32((int32_t *)in_data, out_data, (int32_t *)out_value, in_shape, &arg_min_max->compute_); 
+ } else { + ret = NNACL_UNSUPPORTED_DATA_TYPE; + } + + if (arg_min_max->arg_elements_alloc_) { + self->env_->Free(self->env_->allocator_, arg_min_max->compute_.arg_elements_); + arg_min_max->compute_.arg_elements_ = NULL; + } + return ret; +} + +KernelBase *CreateArgMinMax(OpParameter *param, int data_type) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)malloc(sizeof(ArgMinMaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arg_min_max); + memset(arg_min_max, 0, sizeof(ArgMinMaxStruct)); + + arg_min_max->base_.Prepare = ArgMinMaxPrepare; + arg_min_max->base_.Resize = ArgMinMaxResize; + arg_min_max->base_.Release = DefaultRelease; + arg_min_max->base_.Compute = ArgMinMaxCompute; + return (KernelBase *)arg_min_max; +} + +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeInt32, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat16, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat32, CreateArgMinMax) + +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeInt32, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat16, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat32, CreateArgMinMax) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h new file mode 100644 index 0000000000000000000000000000000000000000..52c7f7f3ed5742c07e0cd8729e47fb91e01492c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
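ArgMinMaxResize above relies on ComputeStrides from nnacl_common.h. Assuming it computes ordinary row-major strides (strides[i] is the product of all later dims), an equivalent sketch:

/* Assumed row-major semantics of ComputeStrides as used in ArgMinMaxResize:
 * shape = {2, 3, 4}  ->  strides = {12, 4, 1}. */
static void ComputeStridesSketch(const int *shape, int *strides, size_t ndim) {
  int stride = 1;
  for (int i = (int)ndim - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= shape[i];
  }
}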
+ */ + +#ifndef NNACL_KERNEL_ARG_MIN_MAX_H_ +#define NNACL_KERNEL_ARG_MIN_MAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#ifdef ENABLE_ARM64 +#include <arm_neon.h> +#endif + +typedef struct ArgElement { + uint32_t index_; + union ArgData { + int8_t i8_data_; + int32_t i_data_; + float f_data_; +#ifdef ENABLE_ARM +#ifdef ENABLE_FP16 + float16_t f16_data_; +#endif +#endif + } data_; +} ArgElement; + +typedef int (*COMPARE_FUNCTION)(const void *a, const void *b); + +typedef struct ArgMinMaxComputeParam { + int32_t axis_; + int32_t dims_size_; + int32_t topk_; + bool get_max_; + bool keep_dims_; + bool out_value_; + int32_t in_strides_[COMM_SHAPE_SIZE]; + int32_t out_strides_[COMM_SHAPE_SIZE]; + ArgElement *arg_elements_; +} ArgMinMaxComputeParam; + +typedef struct ArgMinMaxStruct { + KernelBase base_; + ArgMinMaxComputeParam compute_; + bool arg_elements_alloc_; +} ArgMinMaxStruct; + +KernelBase *CreateArgMinMax(OpParameter *param, int data_type); + +#endif  // NNACL_KERNEL_ARG_MIN_MAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c new file mode 100644 index 0000000000000000000000000000000000000000..973efbf57348ace2ed817b1683839a5a76d40fc1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c @@ -0,0 +1,725 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
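The COMPARE_FUNCTION typedef and ArgElement union above are shaped for qsort-based top-k selection. An illustrative comparator under that assumption; the name is not from the patch:

#include <stdlib.h>

/* Sorts ArgElement entries by f_data_ in descending order (ArgMax-style). */
static int CompareArgElementDescFp32(const void *a, const void *b) {
  float fa = ((const ArgElement *)a)->data_.f_data_;
  float fb = ((const ArgElement *)b)->data_.f_data_;
  if (fa < fb) return 1;
  if (fa > fb) return -1;
  return 0;
}
/* usage sketch: qsort(elements, n, sizeof(ArgElement), CompareArgElementDescFp32); */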
+ */ + +#include "nnacl_c/kernel/arithmetic.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/mul_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arithmetic_fp16.h" +#endif + +void InitArithmeticRunFunction(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + + ArithmeticFuncions fun_table[] = { + {PrimType_MulFusion, ActType_Relu, ElementMulRelu, ElementMulReluInt, NULL, ElementOptMulRelu, ElementOptMulReluInt, + NULL}, + {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6, ElementMulRelu6Int, NULL, ElementOptMulRelu6, + ElementOptMulRelu6Int, NULL}, + {PrimType_MulFusion, ActType_No, ElementMul, ElementMulInt, NULL, ElementOptMul, ElementOptMulInt, NULL}, + {PrimType_AddFusion, ActType_Relu, ElementAddRelu, NULL, NULL, ElementOptAddRelu, NULL, NULL}, + {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6, NULL, NULL, ElementOptAddRelu6, NULL, NULL}, + {PrimType_AddFusion, ActType_No, ElementAdd, ElementAddInt, NULL, ElementOptAdd, ElementOptAddInt, NULL}, + {PrimType_SubFusion, ActType_Relu, ElementSubRelu, NULL, NULL, ElementOptSubRelu, NULL, NULL}, + {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6, NULL, NULL, ElementOptSubRelu6, NULL, NULL}, + {PrimType_SubFusion, ActType_No, ElementSub, ElementSubInt, NULL, ElementOptSub, ElementOptSubInt, NULL}, + {PrimType_DivFusion, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL}, + {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL}, + {PrimType_DivFusion, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL}, + {PrimType_RealDiv, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL}, + {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL}, + {PrimType_RealDiv, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL}, + {PrimType_LogicalAnd, ActType_No, ElementLogicalAnd, ElementLogicalAndInt, ElementLogicalAndBool, + ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool}, + {PrimType_LogicalOr, ActType_No, ElementLogicalOr, NULL, ElementLogicalOrBool, NULL, NULL, ElementOptLogicalOrBool}, + {PrimType_Maximum, ActType_No, ElementMaximum, ElementMaximumInt, NULL, ElementOptMaximum, ElementOptMaximumInt, + NULL}, + {PrimType_Minimum, ActType_No, ElementMinimum, ElementMinimumInt, NULL, ElementOptMinimum, ElementOptMinimumInt, + NULL}, + {PrimType_FloorMod, ActType_No, ElementFloorMod, ElementFloorModInt, NULL, ElementOptFloorMod, + ElementOptFloorModInt, NULL}, + {PrimType_FloorDiv, ActType_No, ElementFloorDiv, ElementFloorDivInt, NULL, ElementOptFloorDiv, + ElementOptFloorDivInt, NULL}, + {PrimType_Mod, ActType_No, ElementMod, ElementModInt, NULL, ElementOptMod, ElementOptModInt, NULL}, + {PrimType_SquaredDifference, ActType_No, ElementSquaredDifference, NULL, NULL, ElementOptSquaredDifference, NULL, + NULL}}; + + size_t length = sizeof(fun_table) / sizeof(ArithmeticFuncions); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == arithmetic->primitive_type_ && + fun_table[i].activation_type_ == ((ArithmeticParameter *)(arithmetic->base_.param_))->activation_type_) { + arithmetic->functions_ = fun_table[i]; + return; + } + } +} + +int ArithmeticRelease(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + for 
(int i = 0; i < TWO_TENSOR; i++) { + if (arithmetic->broadcast_buffer_[i] != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->broadcast_buffer_[i]); + arithmetic->broadcast_buffer_[i] = NULL; + } + } + + for (int i = 0; i < arithmetic->block_boundary_infos_size_; i++) { + if (arithmetic->block_boundary_infos_[i].a_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].a_offset_); + arithmetic->block_boundary_infos_[i].a_offset_ = NULL; + } + if (arithmetic->block_boundary_infos_[i].b_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].b_offset_); + arithmetic->block_boundary_infos_[i].b_offset_ = NULL; + } + } + arithmetic->block_boundary_infos_size_ = 0; + + if (arithmetic->a_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->a_matrix_.batch_post_sum_); + arithmetic->a_matrix_.batch_post_sum_ = NULL; + } + + if (arithmetic->b_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->b_matrix_.batch_post_sum_); + arithmetic->b_matrix_.batch_post_sum_ = NULL; + } + + if (arithmetic->c_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->c_matrix_.batch_post_sum_); + arithmetic->c_matrix_.batch_post_sum_ = NULL; + } + return NNACL_OK; +} + +void ArithmeticComputeOffset(ArithmeticStruct *arithmetic, int task_id) { + ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id]; + block_info->init_offset_ = true; + + int64_t b_start = block_info->batch_begin_; + int64_t b_end = block_info->batch_end_; + int64_t s_end = block_info->size_end_; + if (s_end != 0) { + ++b_end; + } + int offset_index = 0; + for (; b_start < b_end; ++b_start) { + int64_t delta = b_start; + int64_t a_offset = 0; + int64_t b_offset = 0; + for (int j = 0; j <= arithmetic->batch_tail_dim_; ++j) { + if (j > 0) { + delta = delta % arithmetic->c_matrix_.batch_post_sum_[j]; + } + if (j < arithmetic->batch_tail_dim_) { + a_offset += (delta / arithmetic->c_matrix_.batch_post_sum_[j + 1] * arithmetic->a_matrix_.shape_[j] / + arithmetic->c_matrix_.shape_[j]) * + arithmetic->a_matrix_.batch_post_sum_[j + 1]; + b_offset += (delta / arithmetic->c_matrix_.batch_post_sum_[j + 1] * arithmetic->b_matrix_.shape_[j] / + arithmetic->c_matrix_.shape_[j]) * + arithmetic->b_matrix_.batch_post_sum_[j + 1]; + } else { + a_offset += (delta * arithmetic->a_matrix_.shape_[j] / arithmetic->c_matrix_.shape_[j]); + b_offset += (delta * arithmetic->b_matrix_.shape_[j] / arithmetic->c_matrix_.shape_[j]); + } + } + block_info->a_offset_[offset_index] = a_offset * arithmetic->a_matrix_.inner_size_ * arithmetic->in_data_size_; + block_info->b_offset_[offset_index] = b_offset * arithmetic->b_matrix_.inner_size_ * arithmetic->in_data_size_; + offset_index++; + } +} + +int ArithmeticDoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)base; + int data_type = arithmetic->base_.in_[FIRST_INPUT]->data_type_; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(input1); + + if (data_type == kNumberTypeFloat32) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_f32_); + return arithmetic->functions_.optimzie_f32_((const float *)input0, (const float *)input1, (float *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + 
NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_f32_); + return arithmetic->functions_.compute_f32_((const float *)input0, (const float *)input1, (float *)output, size); + } + } + + if (data_type == kNumberTypeBool) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_bool_); + return arithmetic->functions_.optimzie_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_bool_); + return arithmetic->functions_.compute_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size); + } + } + + if (data_type == kNumberTypeInt32) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_int_); + return arithmetic->functions_.optimzie_int_((const int *)input0, (const int *)input1, (int *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_int_); + return arithmetic->functions_.compute_int_((const int *)input0, (const int *)input1, (int *)output, size); + } + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ArithmeticRun(void *cdata, int task_id, float l, float r) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)cdata; + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id >= arithmetic->block_boundary_infos_size_, NNACL_ERR); + + if (arithmetic->block_boundary_infos_[task_id].init_offset_ == false) { + ArithmeticComputeOffset(arithmetic, task_id); + } + + ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id]; + int64_t b_start = block_info->batch_begin_; + int64_t s_start = block_info->size_begin_; + int64_t s_end = block_info->size_end_; + int64_t index_start = 0; + int64_t index_end = block_info->batch_end_ - b_start; + uint8_t *a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + uint8_t *b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + uint8_t *c_ptr = (uint8_t *)(arithmetic->c_matrix_.data_) + + (b_start * arithmetic->c_matrix_.inner_size_ + s_start) * arithmetic->out_data_size_; + if (arithmetic->a_matrix_.inner_size_ > 1) { + a_ptr += s_start * arithmetic->in_data_size_; + } + if (arithmetic->b_matrix_.inner_size_ > 1) { + b_ptr += s_start * arithmetic->in_data_size_; + } + + if (index_start == index_end) { + return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end - s_start); + } + + int64_t size = arithmetic->c_matrix_.inner_size_ - s_start; + int ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, size); + if (ret != NNACL_OK) { + return ret; + } + + ++index_start; + c_ptr += size * arithmetic->out_data_size_; + int64_t c_stride = arithmetic->c_matrix_.inner_size_ * arithmetic->out_data_size_; + for (; index_start < index_end; ++index_start) { + a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, arithmetic->c_matrix_.inner_size_); + if (ret != NNACL_OK) { + return ret; + } + c_ptr += c_stride; + } + if (s_end == 0) { + return NNACL_OK; + } + a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + 
block_info->b_offset_[index_start];
+  return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end);
+}
+
+void ResetArithmeticMatric(KernelBase *base, ArithmeticMatrixInfo *matrix) {
+  matrix->is_valid_ = false;
+  matrix->data_ = NULL;
+  matrix->inner_size_ = 1;
+  matrix->shape_size_ = 0;
+
+  if (matrix->batch_post_sum_ != NULL) {
+    base->env_->Free(base->env_->allocator_, matrix->batch_post_sum_);
+    matrix->batch_post_sum_ = NULL;
+  }
+}
+
+int UpdateArithmeticParameter(ArithmeticStruct *arithmetic) {
+  NNACL_CHECK_TRUE_RET(arithmetic->a_matrix_.shape_size_ == arithmetic->b_matrix_.shape_size_,
+                       NNACL_ARITHMETIC_SHAPE_INVALID);
+
+  arithmetic->ndim_ = arithmetic->a_matrix_.shape_size_;
+  ResetArithmeticMatric(&arithmetic->base_, &arithmetic->c_matrix_);
+
+  for (size_t i = 0; i < arithmetic->ndim_; ++i) {
+    NNACL_CHECK_TRUE_RET(arithmetic->a_matrix_.shape_[i] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID);
+    NNACL_CHECK_TRUE_RET(arithmetic->b_matrix_.shape_[i] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID);
+    arithmetic->in_shape0_[i] = arithmetic->a_matrix_.shape_[i];
+    arithmetic->in_shape1_[i] = arithmetic->b_matrix_.shape_[i];
+    arithmetic->out_shape_[i] = MSMAX(arithmetic->in_shape0_[i], arithmetic->in_shape1_[i]);
+    arithmetic->c_matrix_.shape_[arithmetic->c_matrix_.shape_size_++] =
+      MSMAX(arithmetic->a_matrix_.shape_[i], arithmetic->b_matrix_.shape_[i]);
+  }
+  return NNACL_OK;
+}
+
+int OptimizeArithmeticShape(ArithmeticStruct *arithmetic) {
+  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
+  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
+  arithmetic->ndim_ = a->shape_size_ >= b->shape_size_ ? a->shape_size_ : b->shape_size_;
+
+  int shape0[MAX_LEN] = {0};
+  int shape1[MAX_LEN] = {0};
+  /* pad both shapes with leading 1s up to ndim_ */
+  int i = 0;
+  for (; i < arithmetic->ndim_; ++i) {
+    shape0[i] = 1;
+    shape1[i] = 1;
+  }
+
+  /* right-align the original dims inside the padded shapes */
+  int a_matrix_size = arithmetic->ndim_ - a->shape_size_;
+  for (i = a_matrix_size; i < arithmetic->ndim_; i++) {
+    shape0[i] = a->shape_[i - a_matrix_size];
+  }
+
+  int b_matrix_size = arithmetic->ndim_ - b->shape_size_;
+  for (i = b_matrix_size; i < arithmetic->ndim_; i++) {
+    shape1[i] = b->shape_[i - b_matrix_size];
+  }
+
+  /* horizontal shape dims */
+  int shape0_temp[MAX_LEN] = {0};
+  int shape1_temp[MAX_LEN] = {0};
+  int shape_temp_size = 0;
+  for (i = 0; i < arithmetic->ndim_;) {  // horizontal pass: fold each run where one shape is all 1s into a single dim.
+    shape0_temp[shape_temp_size] = shape0[i];
+    shape1_temp[shape_temp_size] = shape1[i];
+    shape_temp_size++;
+    if (shape0[i] != 1 && shape1[i] != 1) {
+      ++i;
+      continue;
+    }
+
+    size_t j0 = i;
+    while (j0 < arithmetic->ndim_ && shape0[j0] == 1) {
+      ++j0;
+    }
+    size_t j1 = i;
+    while (j1 < arithmetic->ndim_ && shape1[j1] == 1) {
+      ++j1;
+    }
+    size_t j = MSMAX(j0, j1);
+    while ((++i) < j) {
+      shape0_temp[shape_temp_size - 1] *= shape0[i];
+      shape1_temp[shape_temp_size - 1] *= shape1[i];
+    }
+  }
+
+  arithmetic->a_matrix_.shape_size_ = 0;
+  arithmetic->b_matrix_.shape_size_ = 0;
+
+  for (i = 0; i < shape_temp_size;) {  // vertical pass: merge adjacent dims that are equal in both shapes.
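+    // dims where both shapes are 1 carry no information and are dropped outright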
+ if (shape0_temp[i] == 1 && shape1_temp[i] == 1) { + ++i; + continue; + } + shape0[arithmetic->a_matrix_.shape_size_++] = shape0_temp[i]; + shape1[arithmetic->b_matrix_.shape_size_++] = shape1_temp[i]; + if (shape0_temp[i] != shape1_temp[i]) { + ++i; + continue; + } + while ((++i) < shape_temp_size) { + if (shape0_temp[i] != shape1_temp[i]) { + break; + } + shape0[arithmetic->a_matrix_.shape_size_ - 1] *= shape0_temp[i]; + shape1[arithmetic->b_matrix_.shape_size_ - 1] *= shape1_temp[i]; + } + } + + memcpy(arithmetic->a_matrix_.shape_, shape0, arithmetic->a_matrix_.shape_size_ * sizeof(int)); + memcpy(arithmetic->b_matrix_.shape_, shape1, arithmetic->b_matrix_.shape_size_ * sizeof(int)); + + return UpdateArithmeticParameter(arithmetic); +} + +int ResetArithmeticStatus(ArithmeticStruct *arithmetic) { + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->a_matrix_); + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->b_matrix_); + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->c_matrix_); + + arithmetic->a_matrix_.shape_size_ = arithmetic->base_.in_[FIRST_INPUT]->shape_size_; + memcpy(arithmetic->a_matrix_.shape_, arithmetic->base_.in_[FIRST_INPUT]->shape_, + arithmetic->a_matrix_.shape_size_ * sizeof(int)); + arithmetic->b_matrix_.shape_size_ = arithmetic->base_.in_[SECOND_INPUT]->shape_size_; + memcpy(arithmetic->b_matrix_.shape_, arithmetic->base_.in_[SECOND_INPUT]->shape_, + arithmetic->b_matrix_.shape_size_ * sizeof(int)); + + return OptimizeArithmeticShape(arithmetic); +} + +void ArithmeticDoBroadcast(ArithmeticStruct *arithmetic, void *in_data, void *out_data, int input_index) { + int *in_shape = input_index == FIRST_INPUT ? arithmetic->in_shape0_ : arithmetic->in_shape1_; + int *in_stride = input_index == FIRST_INPUT ? arithmetic->in_strides0_ : arithmetic->in_strides1_; + int *multiples = input_index == FIRST_INPUT ? 
arithmetic->multiples0_ : arithmetic->multiples1_;
+  arithmetic->tile_function_(in_data, out_data, 0, arithmetic->ndim_, in_shape, in_stride,
+                             arithmetic->out_strides_, multiples);
+}
+
+int CheckDivDataInvalid(ArithmeticStruct *arithmetic) {
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]);
+  if ((arithmetic->primitive_type_ == PrimType_DivFusion || arithmetic->primitive_type_ == PrimType_RealDiv) &&
+      arithmetic->base_.in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32) {
+    int element_num = NNACLGetElementNum(arithmetic->base_.in_[SECOND_INPUT]);
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]->data_);
+    int *int_data = (int *)(arithmetic->base_.in_[SECOND_INPUT]->data_);
+    for (int i = 0; i < element_num; i++) {
+      if (int_data[i] == 0) {
+        return NNACL_INPUT_TENSOR_ERROR;
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int ArithmeticBroadCastConstTensor(ArithmeticStruct *arithmetic) {
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
+
+  CalcStructMultiplesAndStrides(arithmetic);
+
+#ifdef MSLITE_ENABLE_CLOUD_INFERENCE
+  bool prefer_explicit_broadcast = false;
+#else
+  bool prefer_explicit_broadcast = arithmetic->ndim_ != 1;
+#endif
+  prefer_explicit_broadcast =
+    prefer_explicit_broadcast && (arithmetic->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeBool);
+
+  bool exist_broadcast_ = false;
+  int buffer_size = NNACLGetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]) * arithmetic->in_data_size_;
+  if (arithmetic->a_matrix_.is_const_) {
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[FIRST_INPUT]->data_);
+    if (arithmetic->in_elements_num0_ != arithmetic->out_elements_num_ && prefer_explicit_broadcast) {
+      exist_broadcast_ = true;
+
+      arithmetic->a_matrix_.data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size);
+      NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_);
+      arithmetic->broadcast_buffer_[Index0] = arithmetic->a_matrix_.data_;
+
+      ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[FIRST_INPUT]->data_, arithmetic->a_matrix_.data_,
+                            Index0);
+      arithmetic->in_elements_num0_ = arithmetic->out_elements_num_;
+
+      // after the explicit broadcast, shape and strides must match the output
+      for (size_t i = 0; i < arithmetic->ndim_; ++i) {
+        arithmetic->in_shape0_[i] = arithmetic->out_shape_[i];
+        arithmetic->in_strides0_[i] = arithmetic->out_strides_[i];
+      }
+      memcpy(arithmetic->a_matrix_.shape_, arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int));
+      arithmetic->a_matrix_.is_valid_ = true;
+    }
+  }
+
+  if (arithmetic->b_matrix_.is_const_) {
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]->data_);
+    int ret = CheckDivDataInvalid(arithmetic);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+    if (arithmetic->in_elements_num1_ != arithmetic->out_elements_num_ && prefer_explicit_broadcast) {
+      exist_broadcast_ = true;
+
+      arithmetic->b_matrix_.data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size);
+      NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_);
+      arithmetic->broadcast_buffer_[Index1] = arithmetic->b_matrix_.data_;
+
+      ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[Index1]->data_, arithmetic->b_matrix_.data_, Index1);
+      arithmetic->in_elements_num1_ = arithmetic->out_elements_num_;
+      // after the explicit broadcast, shape and strides must match the output
+      for (size_t i = 0; i < arithmetic->ndim_; ++i) {
+        arithmetic->in_shape1_[i] = arithmetic->out_shape_[i];
+        arithmetic->in_strides1_[i] = arithmetic->out_strides_[i];
+      }
+
+      memcpy(arithmetic->b_matrix_.shape_,
arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int));
+      arithmetic->b_matrix_.is_valid_ = true;
+    }
+  }
+  if (!exist_broadcast_) {
+    return NNACL_OK;
+  }
+  return OptimizeArithmeticShape(arithmetic);
+}
+
+int ArithmeticComputeOfflineInfo(ArithmeticStruct *arithmetic) {
+  int break_pos = -1;
+  int last_dim = arithmetic->a_matrix_.shape_size_ - 1;
+  for (int i = last_dim; i >= 0; --i) {
+    if (arithmetic->a_matrix_.shape_[i] != arithmetic->b_matrix_.shape_[i]) {
+      break_pos = i;
+      break;
+    }
+  }
+  arithmetic->batch_tail_dim_ = break_pos;
+  if (break_pos == last_dim && arithmetic->batch_tail_dim_ >= 0) {
+    --arithmetic->batch_tail_dim_;
+  }
+
+  for (int i = last_dim; i > arithmetic->batch_tail_dim_; --i) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->a_matrix_.inner_size_, arithmetic->a_matrix_.shape_[i], NNACL_ERR);
+    arithmetic->a_matrix_.inner_size_ *= arithmetic->a_matrix_.shape_[i];
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->b_matrix_.inner_size_, arithmetic->b_matrix_.shape_[i], NNACL_ERR);
+    arithmetic->b_matrix_.inner_size_ *= arithmetic->b_matrix_.shape_[i];
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->c_matrix_.inner_size_, arithmetic->c_matrix_.shape_[i], NNACL_ERR);
+    arithmetic->c_matrix_.inner_size_ *= arithmetic->c_matrix_.shape_[i];
+  }
+
+  arithmetic->a_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
+    arithmetic->base_.env_->allocator_, (arithmetic->a_matrix_.shape_size_ + 1) * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.batch_post_sum_);
+  for (int i = 0; i < arithmetic->a_matrix_.shape_size_ + 1; i++) {
+    arithmetic->a_matrix_.batch_post_sum_[i] = 1;
+  }
+
+  arithmetic->b_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
+    arithmetic->base_.env_->allocator_, (arithmetic->b_matrix_.shape_size_ + 1) * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.batch_post_sum_);
+  for (int i = 0; i < arithmetic->b_matrix_.shape_size_ + 1; i++) {
+    arithmetic->b_matrix_.batch_post_sum_[i] = 1;
+  }
+
+  arithmetic->c_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
+    arithmetic->base_.env_->allocator_, (arithmetic->c_matrix_.shape_size_ + 1) * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.batch_post_sum_);
+  for (int i = 0; i < arithmetic->c_matrix_.shape_size_ + 1; i++) {
+    arithmetic->c_matrix_.batch_post_sum_[i] = 1;
+  }
+
+  // batch_post_sum_[i] becomes the suffix product of the batch dims [i .. batch_tail_dim_];
+  // ArithmeticComputeOffset uses it to decompose a linear batch index into per-dim offsets.
+  for (int i = arithmetic->batch_tail_dim_; i >= 0; --i) {
+    if (i == arithmetic->batch_tail_dim_) {
+      arithmetic->a_matrix_.batch_post_sum_[i] = arithmetic->a_matrix_.shape_[i];
+      arithmetic->b_matrix_.batch_post_sum_[i] = arithmetic->b_matrix_.shape_[i];
+      arithmetic->c_matrix_.batch_post_sum_[i] = arithmetic->c_matrix_.shape_[i];
+    } else {
+      arithmetic->a_matrix_.batch_post_sum_[i] =
+        arithmetic->a_matrix_.shape_[i] * arithmetic->a_matrix_.batch_post_sum_[i + 1];
+      arithmetic->b_matrix_.batch_post_sum_[i] =
+        arithmetic->b_matrix_.shape_[i] * arithmetic->b_matrix_.batch_post_sum_[i + 1];
+      arithmetic->c_matrix_.batch_post_sum_[i] =
+        arithmetic->c_matrix_.shape_[i] * arithmetic->c_matrix_.batch_post_sum_[i + 1];
+    }
+  }
+
+  arithmetic->scalar_opt_ = false;
+  if (arithmetic->a_matrix_.inner_size_ == 1) {
+    arithmetic->in_elements_num0_ = 1;
+    arithmetic->scalar_opt_ = true;
+  }
+  if (arithmetic->b_matrix_.inner_size_ == 1) {
+    arithmetic->in_elements_num1_ = 1;
+    arithmetic->scalar_opt_ = true;
+  }
+  return NNACL_OK;
+}
+
+int ArithmeticChooseThreadCuttingStrategy(ArithmeticStruct *arithmetic) {
+  int total_num = NNACLGetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]);
+  arithmetic->base_.thread_nr_ =
+    arithmetic->base_.UpdateThread(TC_TYPE(arithmetic->primitive_type_, arithmetic->functions_.activation_type_), 1, 1,
+                                   total_num, arithmetic->base_.thread_nr_);
+
+  int64_t block_size = UP_DIV(total_num, arithmetic->base_.thread_nr_);
+  int64_t split_point = 0;
+  // walk the flattened output in block_size steps; each block records its batch range and
+  // the residual offsets inside the first and last batch.
+  while (split_point < total_num) {
+    int64_t start = split_point;
+    int64_t end = start + block_size;
+    if (end > total_num) {
+      end = total_num;
+    }
+    ArithmeticBlockBoundaryInfo block_boundary_info;
+    block_boundary_info.size_begin_ = start % arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.size_end_ = end % arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.batch_begin_ = start / arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.batch_end_ = end / arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.init_offset_ = false;
+
+    int max_offset_size = block_boundary_info.batch_end_ - block_boundary_info.batch_begin_ + TWO_TENSOR;
+    block_boundary_info.a_offset_ =
+      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.a_offset_);
+    block_boundary_info.b_offset_ =
+      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.b_offset_);
+
+    arithmetic->block_boundary_infos_[arithmetic->block_boundary_infos_size_++] = block_boundary_info;
+    split_point = end;
+  }
+
+  arithmetic->base_.thread_nr_ = arithmetic->block_boundary_infos_size_;
+  return NNACL_OK;
+}
+
+int ArithmeticResize(struct KernelBase *self) {
+  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
+
+  ArithmeticRelease(&arithmetic->base_);
+
+  NNACL_CHECK_TRUE_RET(arithmetic->in_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
+  NNACL_CHECK_TRUE_RET(arithmetic->out_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
+  arithmetic->in_elements_num0_ = NNACLGetElementNum(self->in_[FIRST_INPUT]);
+  arithmetic->in_elements_num1_ = NNACLGetElementNum(self->in_[SECOND_INPUT]);
+  arithmetic->out_elements_num_ = NNACLGetElementNum(self->out_[OUTPUT_INDEX]);
+
+  int ret = ResetArithmeticStatus(arithmetic);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = ArithmeticBroadCastConstTensor(arithmetic);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = ArithmeticComputeOfflineInfo(arithmetic);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  return ArithmeticChooseThreadCuttingStrategy(arithmetic);
+}
+
+int ArithmeticPrepare(struct KernelBase *self) {
+  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
+
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+
+  NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
+
+  if (self->param_->quant_type_ != Quant_None) {
+    return NNACL_ERR;
+  }
+
+  arithmetic->primitive_type_ = self->param_->type_;
+  if (self->param_->type_ == PrimType_Eltwise) {
+    switch (((ArithmeticParameter *)(self->param_))->eltwise_mode_) {
+      case Eltwise_PROD:
+        arithmetic->primitive_type_ =
PrimType_MulFusion; + break; + case Eltwise_SUM: + arithmetic->primitive_type_ = PrimType_AddFusion; + break; + case Eltwise_MAXIMUM: + arithmetic->primitive_type_ = PrimType_Maximum; + break; + default: + return NNACL_ELTWISE_INVALID_MOD; + } + } + arithmetic->init_function_(self); + + arithmetic->a_matrix_.is_const_ = NNACLIsConst(self->in_[FIRST_INPUT]); + arithmetic->b_matrix_.is_const_ = NNACLIsConst(self->in_[SECOND_INPUT]); + return NNACL_OK; +} + +int ArithmeticCompute(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != self->in_[SECOND_INPUT]->data_type_, + NNACL_ARITHMETIC_DATA_TYPE_UNMATCH); + + if (self->train_session_) { + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + } + + if (false == arithmetic->a_matrix_.is_valid_) { + arithmetic->a_matrix_.data_ = self->in_[FIRST_INPUT]->data_; + } + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_); + + if (!arithmetic->b_matrix_.is_const_) { + int ret = CheckDivDataInvalid(arithmetic); + if (ret != NNACL_OK) { + return ret; + } + } + if (false == arithmetic->b_matrix_.is_valid_) { + arithmetic->b_matrix_.data_ = self->in_[SECOND_INPUT]->data_; + } + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_); + + arithmetic->c_matrix_.data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.data_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticRun, self, self->thread_nr_); +} + +KernelBase *CreateArithmetic(OpParameter *param, int data_type) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)malloc(sizeof(ArithmeticStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic); + memset(arithmetic, 0, sizeof(ArithmeticStruct)); + arithmetic->in_data_size_ = DataTypeCSize(data_type); + arithmetic->out_data_size_ = DataTypeCSize(data_type); + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->tile_function_ = TileOneDimensionFp32; + arithmetic->init_function_ = InitArithmeticRunFunction; + arithmetic->execute_ = ArithmeticDoExecute; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticResize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompute; + return (KernelBase *)arithmetic; +} + +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat32, 
CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeInt32, CreateArithmetic) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h new file mode 100644 index 0000000000000000000000000000000000000000..7d261b40d3378f3e965183e4945502e50313f8bb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h @@ -0,0 +1,97 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_ARITHMETIC_H_ +#define NNACL_KERNEL_ARITHMETIC_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/arithmetic_parameter.h" + +typedef struct ArithmeticFuncions { + int primitive_type_; + int activation_type_; + int (*compute_f32_)(const float *in1, const float *in2, float *out, int ele); + int (*compute_int_)(const int *in1, const int *in2, int *out, int ele); + int (*compute_bool_)(const bool *in1, const bool *in2, bool *out, int ele); + int (*optimzie_f32_)(const float *in1, const float *in2, float *out, int ele, bool scalar); + int (*optimzie_int_)(const int *in1, const int *in2, int *out, int ele, bool scalar); + int (*optimzie_bool_)(const bool *in1, const bool *in2, bool *out, int ele, bool scalar); +} ArithmeticFuncions; + +typedef struct ArithmeticMatrixInfo { + bool is_const_; + bool is_valid_; + void *data_; + int64_t inner_size_; + int shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int shape_size_; + int *batch_post_sum_; /* shape size + 1 */ +} ArithmeticMatrixInfo; + +typedef struct ArithmeticBlockBoundaryInfo { + int batch_begin_; + int batch_end_; + int size_begin_; // start-offset under the begin batch + int size_end_; // end-num under the ending batch + int *a_offset_; + int *b_offset_; + bool init_offset_; +} ArithmeticBlockBoundaryInfo; + +typedef struct ArithmeticStruct { + KernelBase base_; + bool scalar_opt_; + int primitive_type_; + int ndim_; + int in_data_size_; + int out_data_size_; + int batch_tail_dim_; + + ArithmeticMatrixInfo a_matrix_; + ArithmeticMatrixInfo b_matrix_; + ArithmeticMatrixInfo c_matrix_; + ArithmeticFuncions functions_; + + void *broadcast_buffer_[TWO_TENSOR]; + int block_boundary_infos_size_; + ArithmeticBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM]; + + int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_elements_num0_; + int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_elements_num1_; + int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_elements_num_; + int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + + void (*tile_function_)(const void *inPtr, void *outPtr, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); + int (*execute_)(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size); + void (*init_function_)(KernelBase *base); +} ArithmeticStruct; + +KernelBase *CreateArithmetic(OpParameter *param, int data_type); +int ArithmeticPrepare(struct KernelBase *self); +int ArithmeticRelease(struct KernelBase *self); +int ArithmeticCompute(struct KernelBase *self); +int ArithmeticResize(struct KernelBase *self); + +#endif // NNACL_KERNEL_ARITHMETIC_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c new file mode 100644 index 0000000000000000000000000000000000000000..5fee66477d59326498ae8298fbc057e93954bd32 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c @@ -0,0 +1,166 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/arithmetic_compare.h" +#include "nnacl_c/kernel/arithmetic.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_compare_fp32.h" + +typedef struct ArithmeticCompareFuncions { + int primitive_type_; + int (*compute_f32_)(const float *input0, const float *input1, uint8_t *output, int element_size); + int (*compute_i32_)(const int *input0, const int *input1, uint8_t *output, int element_size); + int (*optimize_f32)(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); + int (*optimize_i32)(const int *input0, const int *input1, uint8_t *output, int element_size, bool first_scalar); + int (*compute_i64)(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size); + int (*optimize_i64)(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar); + int (*compute_bool)(const bool *input0, const bool *input1, uint8_t *output, int element_size); +} ArithmeticCompareFuncions; + +typedef struct ArithmeticCompareStruct { + ArithmeticStruct arithmetic_; + ArithmeticCompareFuncions functions_; +} ArithmeticCompareStruct; + +void InitArithmeticCompareRunFunction(KernelBase *self) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)self; + NNACL_CHECK_NULL_RETURN_VOID(arithmetic_compare); + + ArithmeticCompareFuncions fun_table[] = { + {PrimType_Equal, ElementEqualFp32, ElementEqualInt32, ElementOptEqualFp32, ElementOptEqualInt32, NULL, NULL, + ElementEqualBool}, + {PrimType_NotEqual, ElementNotEqualFp32, ElementNotEqualInt32, ElementOptNotEqualFp32, ElementOptNotEqualInt32, + ElementNotEqualInt64, ElementOptNotEqualInt64, NULL}, + {PrimType_Less, ElementLessFp32, ElementLessInt32, ElementOptLessFp32, ElementOptLessInt32, NULL, NULL, NULL}, + {PrimType_LessEqual, ElementLessEqualFp32, ElementLessEqualInt32, ElementOptLessEqualFp32, ElementOptLessEqualInt32, + NULL, NULL, NULL}, + {PrimType_Greater, ElementGreaterFp32, ElementGreaterInt32, ElementOptGreaterFp32, ElementOptGreaterInt32, NULL, + NULL, NULL}, + {PrimType_GreaterEqual, ElementGreaterEqualFp32, ElementGreaterEqualInt32, ElementOptGreaterEqualFp32, + ElementOptGreaterEqualInt32, NULL, NULL, NULL}}; + + size_t length = sizeof(fun_table) / sizeof(ArithmeticCompareFuncions); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == arithmetic_compare->arithmetic_.primitive_type_) { + arithmetic_compare->functions_ = fun_table[i]; + return; + } + } +} + +int ArithmeticCompareExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)base; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(input1); + + int data_type = base->in_[FIRST_INPUT]->data_type_; + bool first_scalar = arithmetic_compare->arithmetic_.in_elements_num0_ == 1; + + if (data_type == kNumberTypeFloat32) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_f32); + 
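+      // scalar-optimized path: first_scalar tells the kernel whether input0 is the single broadcast value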
return arithmetic_compare->functions_.optimize_f32((const float *)input0, (const float *)input1, + (uint8_t *)output, size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_f32_); + return arithmetic_compare->functions_.compute_f32_((const float *)input0, (const float *)input1, + (uint8_t *)output, size); + } + } + + if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i32); + return arithmetic_compare->functions_.optimize_i32((const int *)input0, (const int *)input1, (uint8_t *)output, + size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i32_); + return arithmetic_compare->functions_.compute_i32_((const int *)input0, (const int *)input1, (uint8_t *)output, + size); + } + } + + if (data_type == kNumberTypeInt64) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i64); + return arithmetic_compare->functions_.optimize_i64((const int64_t *)input0, (const int64_t *)input1, + (uint8_t *)output, size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i64); + return arithmetic_compare->functions_.compute_i64((const int64_t *)input0, (const int64_t *)input1, + (uint8_t *)output, size); + } + } + if (data_type == kNumberTypeBool) { + if (!arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_bool); + return arithmetic_compare->functions_.compute_bool((const bool *)input0, (const bool *)input1, (uint8_t *)output, + size); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ArithmeticCompareResize(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + arithmetic->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_); + return ArithmeticResize(self); +} + +KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)malloc(sizeof(ArithmeticCompareStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_compare); + memset(arithmetic_compare, 0, sizeof(ArithmeticCompareStruct)); + + ArithmeticStruct *arithmetic = (ArithmeticStruct *)arithmetic_compare; + arithmetic->in_data_size_ = DataTypeCSize(data_type); + arithmetic->out_data_size_ = DataTypeCSize(data_type); + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->tile_function_ = TileOneDimensionFp32; + arithmetic->init_function_ = InitArithmeticCompareRunFunction; + arithmetic->execute_ = ArithmeticCompareExecute; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticCompareResize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompute; + return (KernelBase *)arithmetic_compare; +} + +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Equal, 
kNumberTypeBool, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt64, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeInt32, CreateArithmeticCompare) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..868196c0eb07cb89dade41c479f88cf4e030bc23 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ARITHMETIC_COMPARE_H_ +#define NNACL_KERNEL_ARITHMETIC_COMPARE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c new file mode 100644 index 0000000000000000000000000000000000000000..6eaf274168c3cd60523049c0e75f26ea7f5ad2cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c @@ -0,0 +1,199 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/arithmetic_self.h" +#include "nnacl_c/fp32/arithmetic_self_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arithmetic_self_fp16.h" +#endif + +void ArithmeticSelfGetArithmeticSelfFunction(ArithmeticSelfStruct *arithmetic_self, int primitive_type) { + ArithmeticSelfFunction type_func_table[] = { + {PrimType_Abs, ElementAbs, NULL, ElementAbsInt, NULL}, + {PrimType_Cos, ElementCos, NULL, NULL, NULL}, + {PrimType_Log, ElementLog, NULL, NULL, NULL}, + {PrimType_Log1p, ElementLog1p, NULL, NULL, NULL}, + {PrimType_Square, ElementSquare, NULL, NULL, NULL}, + {PrimType_Sqrt, ElementSqrt, NULL, NULL, NULL}, + {PrimType_Rsqrt, ElementRsqrt, NULL, NULL, NULL}, + {PrimType_Sin, ElementSin, NULL, NULL, NULL}, + {PrimType_LogicalNot, ElementLogicalNot, ElementLogicalNotBool, NULL, NULL}, + {PrimType_Floor, ElementFloor, NULL, NULL, NULL}, + {PrimType_Ceil, ElementCeil, NULL, NULL, NULL}, + {PrimType_Round, ElementRound, NULL, NULL, NULL}, + {PrimType_Neg, ElementNegative, NULL, ElementNegativeInt, NULL}, + {PrimType_Reciprocal, ElementReciprocal, NULL, NULL, NULL}, + {PrimType_Erf, ElementErf, NULL, NULL, NULL}, + {PrimType_IsFinite, NULL, NULL, NULL, ElementIsFinite}}; + for (size_t i = 0; i < sizeof(type_func_table) / sizeof(ArithmeticSelfFunction); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + arithmetic_self->function_ = type_func_table[i]; + return; + } + } +} + +void ArithmeticSelfGetArithmeticSelfF16Function(ArithmeticSelfStruct *arithmetic_self, int primitive_type) { +#ifdef ENABLE_FP16 + ArithmeticSelfF16Function type_func_table[] = {{PrimType_Abs, ElementAbsFp16}, + {PrimType_Cos, ElementCosFp16}, + {PrimType_Log, ElementLogFp16}, + {PrimType_Square, ElementSquareFp16}, + {PrimType_Sqrt, ElementSqrtFp16}, + {PrimType_Rsqrt, ElementRsqrtFp16}, + {PrimType_Sin, ElementSinFp16}, + {PrimType_LogicalNot, ElementLogicalNotFp16}, + {PrimType_Floor, ElementFloorFp16}, + {PrimType_Ceil, ElementCeilFp16}, + {PrimType_Round, ElementRoundFp16}, + {PrimType_Neg, ElementNegativeFp16}, + {PrimType_Reciprocal, ElementReciprocalFp16}, + {PrimType_Erf, ElementErfFp16}}; + for (size_t i = 0; i < sizeof(type_func_table) / sizeof(ArithmeticSelfF16Function); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + arithmetic_self->f16_function_ = type_func_table[i]; + return; + } + } +#endif + arithmetic_self->f16_function_.primitive_type_ = primitive_type; + return; +} + +int ArithmeticSelfExecute(ArithmeticSelfStruct *arithmetic_self, int task_id) { + int elements_num = NNACLGetElementNum(arithmetic_self->base_.in_[FIRST_INPUT]); + NNACL_CHECK_TRUE_RET(arithmetic_self->base_.thread_nr_, NNACL_ERR); + int stride = UP_DIV(elements_num, arithmetic_self->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, NNACL_ERR); + int offset = task_id * stride; + int count = NNACL_MIN(stride, elements_num - offset); + if (count <= 0) { + return NNACL_OK; + } + + void *in_data = arithmetic_self->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + void *out_data = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + int in_data_type = arithmetic_self->base_.in_[FIRST_INPUT]->data_type_; + int out_data_type = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_type_; + + if (in_data_type == kNumberTypeFloat32 && out_data_type == kNumberTypeBool) { + 
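+    // float-in/bool-out ops such as IsFinite route through the dedicated func_float_bool_ pointer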
NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_float_bool_);
+    return arithmetic_self->function_.func_float_bool_((float *)in_data + offset, (bool *)out_data + offset, count);
+  }
+
+  if (in_data_type == kNumberTypeFloat32) {
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_);
+    return arithmetic_self->function_.func_((float *)in_data + offset, (float *)out_data + offset, count);
+  }
+
+  if (in_data_type == kNumberTypeBool) {
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_bool_);
+    return arithmetic_self->function_.func_bool_((bool *)in_data + offset, (bool *)out_data + offset, count);
+  }
+
+  if (in_data_type == kNumberTypeInt32) {
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_int_);
+    return arithmetic_self->function_.func_int_((int32_t *)in_data + offset, (int32_t *)out_data + offset, count);
+  }
+
+#ifdef ENABLE_FP16
+  if (in_data_type == kNumberTypeFloat16) {
+    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->f16_function_.func_);
+    return arithmetic_self->f16_function_.func_((float16_t *)in_data + offset, (float16_t *)out_data + offset, count);
+  }
+#endif
+  return NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT;
+}
+
+int ArithmeticSelfRun(void *cdata, int task_id, float l, float r) {
+  ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self);
+  return ArithmeticSelfExecute(arithmetic_self, task_id);
+}
+
+int ArithmeticSelfResize(KernelBase *self) {
+  ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self);
+  self->thread_nr_ = arithmetic_self->base_.UpdateThread(
+    TC_PTYPE(arithmetic_self->op_type_), 1, 1, NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_);
+  return NNACL_OK;
+}
+
+int ArithmeticSelfCompute(KernelBase *self) {
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticSelfRun, self, self->thread_nr_);
+}
+
+int ArithmeticSelfPrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ != ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_[OUTPUT_INDEX]->category_ == ConstTensor, NNACL_OUTPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_[OUTPUT_INDEX]->category_ == ConstScalar, NNACL_OUTPUT_TENSOR_ERROR);
+  return NNACL_OK;
+}
+
+KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type) {
+  ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)malloc(sizeof(ArithmeticSelfStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_self);
+  memset(arithmetic_self, 0, sizeof(ArithmeticSelfStruct));  // zero-init so unmatched function pointers stay NULL
+  ArithmeticSelfGetArithmeticSelfFunction(arithmetic_self, param->type_);
+  ArithmeticSelfGetArithmeticSelfF16Function(arithmetic_self, param->type_);
+  arithmetic_self->op_type_ = param->type_;
+  arithmetic_self->base_.Prepare = ArithmeticSelfPrepare;
+  arithmetic_self->base_.Resize = ArithmeticSelfResize;
+  arithmetic_self->base_.Release = DefaultRelease;
+  arithmetic_self->base_.Compute = ArithmeticSelfCompute;
+  return (KernelBase *)arithmetic_self;
+}
+
+REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeBool, CreateArithmeticSelf)
+
+REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeInt32, CreateArithmeticSelf)
+REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeInt32, CreateArithmeticSelf)
+
+REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat32, CreateArithmeticSelf)
+REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat32, CreateArithmeticSelf)
+REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat32, CreateArithmeticSelf)
+REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log1p, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_IsFinite, kNumberTypeFloat32, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat16, CreateArithmeticSelf) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h new file mode 100644 index 0000000000000000000000000000000000000000..4b8bf8c08ca28c858a6f1a6a72fda17e92d0397f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h @@ -0,0 +1,48 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_ARITHMETIC_SELF_H_ +#define NNACL_KERNEL_ARITHMETIC_SELF_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ArithmeticSelfFunction { + int primitive_type_; + int (*func_)(const float *input, float *output, const int element_size); + int (*func_bool_)(const bool *input, bool *output, const int element_size); + int (*func_int_)(const int *input, int *output, const int element_size); + int (*func_float_bool_)(const float *input, bool *output, const int element_size); +} ArithmeticSelfFunction; + +typedef struct ArithmeticSelfF16Function { + int primitive_type_; +#ifdef ENABLE_FP16 + int (*func_)(const float16_t *input, float16_t *output, int element_size); +#endif +} ArithmeticSelfF16Function; + +typedef struct ArithmeticSelfStruct { + KernelBase base_; + int op_type_; + ArithmeticSelfFunction function_; + ArithmeticSelfF16Function f16_function_; +} ArithmeticSelfStruct; + +KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARITHMETIC_SELF_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..ae0d1519c00fc054372843e0867de48cf1f6eeca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c @@ -0,0 +1,134 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/batch_norm.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/batchnorm_parameter.h"
+#include "nnacl_c/fp32/batchnorm_fp32.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/batchnorm_fp16.h"
+#endif
+
+int BatchNormFillParam(BatchNormStruct *batch_norm) {
+  TensorC *input_tensor = batch_norm->base_.in_[FIRST_INPUT];
+  int in_channel = input_tensor->shape_[input_tensor->shape_size_ - 1];
+
+  TensorC *mean_tensor = batch_norm->base_.in_[SECOND_INPUT];
+  int mean_channel = mean_tensor->shape_[mean_tensor->shape_size_ - 1];
+
+  TensorC *var_tensor = batch_norm->base_.in_[THIRD_INPUT];
+  int var_channel = var_tensor->shape_[var_tensor->shape_size_ - 1];
+
+  if (in_channel != mean_channel || in_channel != var_channel) {
+    return NNACL_BATCH_NORM_CHANNEL_SHAPE_INVALID;
+  }
+
+  batch_norm->channel_ = in_channel;
+  batch_norm->unit_ = 1;
+  for (size_t i = 0; i < input_tensor->shape_size_ - 1; i++) {
+    batch_norm->unit_ *= input_tensor->shape_[i];
+  }
+  if (batch_norm->momentum_ < 0.0f) {
+    batch_norm->momentum_ = 0.0f;
+  }
+  return NNACL_OK;
+}
+
+int BatchNormRun(void *cdata, int task_id, float l, float r) {
+  BatchNormStruct *bn = (BatchNormStruct *)cdata;
+  void *in_data = bn->base_.in_[FIRST_INPUT]->data_;
+  void *out_data = bn->base_.out_[OUTPUT_INDEX]->data_;
+  if (bn->data_type_ == kNumberTypeFloat16) {
+#ifdef ENABLE_FP16
+    BatchNormFp16((float16_t *)in_data, (float16_t *)bn->mean_, (float16_t *)bn->variance_, bn, task_id,
+                  bn->base_.thread_nr_, (float16_t *)out_data);
+#endif
+  } else {
+    BatchNormFp32((float *)in_data, (float *)bn->mean_, (float *)bn->variance_, bn, task_id, bn->base_.thread_nr_,
+                  (float *)out_data);
+  }
+  return NNACL_OK;
+}
+
+int BatchNormReSize(KernelBase *self) {
+  BatchNormStruct *batch_norm = (BatchNormStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(batch_norm);
+
+  int ret = BatchNormFillParam(batch_norm);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  (void)batch_norm->base_.Release(self);
+
+  batch_norm->mean_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(self->in_[SECOND_INPUT]));
+  batch_norm->variance_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(self->in_[THIRD_INPUT]));
+  if (batch_norm->mean_ == NULL || batch_norm->variance_ == NULL) {
+    (void)batch_norm->base_.Release(self);
+    return NNACL_ERR;
+  }
+
+  (void)memcpy(batch_norm->mean_, self->in_[SECOND_INPUT]->data_, NNACLGetSize(self->in_[SECOND_INPUT]));
+  (void)memcpy(batch_norm->variance_, self->in_[THIRD_INPUT]->data_, NNACLGetSize(self->in_[THIRD_INPUT]));
+  return NNACL_OK;
+}
+
+int BatchNormRelease(KernelBase *self) {
+  BatchNormStruct *batch_norm = (BatchNormStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(batch_norm);
+
+  if (batch_norm->mean_ != NULL) {
+    self->env_->Free(self->env_->allocator_, batch_norm->mean_);
+    batch_norm->mean_ = NULL;
+  }
+  if (batch_norm->variance_ != NULL) {
+    self->env_->Free(self->env_->allocator_, batch_norm->variance_);
+    batch_norm->variance_ = NULL;
+  }
+  return NNACL_OK;
+}
+
+int BatchNormPrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+
+  BatchNormStruct *batch_norm = (BatchNormStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(batch_norm);
+  batch_norm->momentum_ = -1.0f;
+  batch_norm->epsilon_ = ((BatchNormParameter *)self->param_)->epsilon_;
+  return NNACL_OK;
+}
+
+int BatchNormCompute(KernelBase *self) {
+  return
self->env_->ParallelLaunch(self->env_->thread_pool_, BatchNormRun, self, self->thread_nr_); +} + +KernelBase *CreateBatchNorm(OpParameter *param, int data_type) { + BatchNormStruct *batch_norm = (BatchNormStruct *)malloc(sizeof(BatchNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(batch_norm); + memset(batch_norm, 0, sizeof(BatchNormStruct)); + batch_norm->data_type_ = data_type; + batch_norm->base_.Prepare = BatchNormPrepare; + batch_norm->base_.Resize = BatchNormReSize; + batch_norm->base_.Release = BatchNormRelease; + batch_norm->base_.Compute = BatchNormCompute; + return (KernelBase *)batch_norm; +} + +REG_KERNEL_CREATOR(PrimType_BatchNorm, kNumberTypeFloat16, CreateBatchNorm) +REG_KERNEL_CREATOR(PrimType_BatchNorm, kNumberTypeFloat32, CreateBatchNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..e1afa44a8e4f78630bf54cc4f402aa2eb9c3db75 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h @@ -0,0 +1,38 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BATCH_NORM_H_ +#define NNACL_KERNEL_BATCH_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct BatchNormStruct { + KernelBase base_; + int data_type_; + void *mean_; + void *variance_; + float momentum_; + int unit_; + int channel_; + float epsilon_; +} BatchNormStruct; + +KernelBase *CreateBatchNorm(OpParameter *param, int data_type); +int BatchNormRelease(KernelBase *self); +int BatchNormFillParam(BatchNormStruct *batch_norm); + +#endif // NNACL_KERNEL_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c new file mode 100644 index 0000000000000000000000000000000000000000..f0953fbc11df83625259399bb6634aef2c27229d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c @@ -0,0 +1,114 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/batch_to_space.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/base/batch_to_space_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/batch_to_space_parameter.h" + +int BatchToSpaceProcessInput(BatchToSpaceStruct *batch_to_space) { + TensorC *block_shape = batch_to_space->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(block_shape); + NNACL_CHECK_NULL_RETURN_ERR(block_shape->data_); + TensorC *crop = batch_to_space->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(crop); + NNACL_CHECK_NULL_RETURN_ERR(crop->data_); + + if (NNACLGetElementNum(block_shape) < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { + return NNACL_BATCH_TO_SPACE_BLOCK_SHAPE_INVALID; + } + if (NNACLGetElementNum(crop) < COMM_SHAPE_SIZE) { + return NNACL_BATCH_TO_SPACE_CROP_INVALID; + } + + int32_t *block_shape_data = (int32_t *)block_shape->data_; + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + batch_to_space->block_shape_[i] = block_shape_data[i]; + } + + int32_t *crops_data = (int32_t *)crop->data_; + batch_to_space->no_crop_ = true; + for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { + batch_to_space->crops_[i] = crops_data[i]; + if (batch_to_space->crops_[i] != 0) { + batch_to_space->no_crop_ = false; + } + } + return NNACL_OK; +} + +int BatchToSpaceCompute(KernelBase *self) { + BatchToSpaceStruct *batch_to_space = (BatchToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_to_space); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + size_t data_size = DataTypeCSize(input->data_type_); + if (self->in_size_ == Num1) { + if (batch_to_space->no_crop_) { + BatchToSpaceNoCropForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, data_size); + } else { + BatchToSpaceForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, batch_to_space->crops_, data_size); + } + } + + if (self->in_size_ == Num3) { + int ret = BatchToSpaceProcessInput(batch_to_space); + if (ret != NNACL_OK) { + return ret; + } + if (batch_to_space->no_crop_) { + BatchToSpaceNoCropForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, data_size); + } else { + BatchToSpaceForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, batch_to_space->crops_, data_size); + } + } + return NNACL_OK; +} + +int BatchToSpaceResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_TRUE_RET(self->in_[FIRST_INPUT]->shape_size_ == COMM_SHAPE_SIZE, NNACL_ERR); + return NNACL_OK; +} + +KernelBase *CreateBatchToSpace(OpParameter *param, int data_type) { + BatchToSpaceStruct *batch_to_space = (BatchToSpaceStruct *)malloc(sizeof(BatchToSpaceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(batch_to_space); + memset(batch_to_space, 0, sizeof(BatchToSpaceStruct)); + BatchToSpaceParameter *bts_param = (BatchToSpaceParameter *)param; + memcpy(batch_to_space->crops_, bts_param->crops_, sizeof(int32_t) * COMM_SHAPE_SIZE); + memcpy(batch_to_space->block_shape_, bts_param->block_shape_, sizeof(int32_t) * BATCH_TO_SPACE_BLOCK_SHAPE_SIZE); + batch_to_space->base_.Prepare = DefaultPrepare1In1Out; + batch_to_space->base_.Resize = BatchToSpaceResize; + 
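/* block_shape_ and crops_ are held by value and copied in CreateBatchToSpace, so there is nothing for Release to free and the shared default handlers suffice. */ +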
batch_to_space->base_.Release = DefaultRelease; + batch_to_space->base_.Compute = BatchToSpaceCompute; + return (KernelBase *)batch_to_space; +} + +REG_KERNEL_CREATOR(PrimType_BatchToSpace, kNumberTypeFloat16, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpace, kNumberTypeFloat32, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpaceND, kNumberTypeFloat16, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpaceND, kNumberTypeFloat32, CreateBatchToSpace) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h new file mode 100644 index 0000000000000000000000000000000000000000..3e75a4a6432e9ac7d1b136eeb3679edd760d4729 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BATCH_TO_SPACE_H_ +#define NNACL_KERNEL_BATCH_TO_SPACE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/batch_to_space_parameter.h" + +typedef struct BatchToSpaceStruct { + KernelBase base_; + bool no_crop_; + int32_t crops_[COMM_SHAPE_SIZE]; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; +} BatchToSpaceStruct; + +KernelBase *CreateBatchToSpace(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_BATCH_TO_SPACE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c new file mode 100644 index 0000000000000000000000000000000000000000..a565a8c2f39031f6755297a852a082abd56717b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c @@ -0,0 +1,131 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/biasadd.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/bias_add.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+
+#define BIAS_ADD_PER_UNIT_LOAD_NUM 2
+#define BIAS_ADD_PER_UNIT_STORE_NUM 1
+#define SPLIT_POINTS_SIZE 32
+
+typedef struct BiasAddStruct {
+  KernelBase base_;
+  int64_t inner_num_;
+  int64_t outer_num_;
+  int64_t total_num_;
+  bool batch_priority_;
+  int64_t split_points_[SPLIT_POINTS_SIZE];
+  int split_points_size_;
+} BiasAddStruct;
+
+int ChooseBiasThreadCuttingStrategy(KernelBase *self) {
+  BiasAddStruct *bias_add = (BiasAddStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(bias_add);
+  self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_BiasAdd), BIAS_ADD_PER_UNIT_LOAD_NUM,
+                                        BIAS_ADD_PER_UNIT_STORE_NUM, bias_add->total_num_, self->thread_nr_);
+  NNACL_CHECK_FALSE(self->thread_nr_ <= 0, NNACL_ERR);
+  if (self->thread_nr_ > SPLIT_POINTS_SIZE) {
+    self->thread_nr_ = SPLIT_POINTS_SIZE;
+  }
+
+  bias_add->split_points_size_ = 0;
+  int64_t block_size = bias_add->total_num_ / self->thread_nr_;
+  int64_t remain_data = bias_add->total_num_ - block_size * self->thread_nr_;
+  int64_t split_point = 0;
+  /* Spread the remainder one element at a time over the leading blocks. */
+  while (split_point < bias_add->total_num_) {
+    bias_add->split_points_[bias_add->split_points_size_++] = split_point;
+    split_point += block_size;
+    if (remain_data > 0) {
+      ++split_point;
+      --remain_data;
+    }
+  }
+  self->thread_nr_ = bias_add->split_points_size_;
+  if (bias_add->inner_num_ >= C64NUM && block_size / bias_add->inner_num_ >= C6NUM) {
+    bias_add->batch_priority_ = true;
+  } else {
+    bias_add->batch_priority_ = false;
+  }
+  return NNACL_OK;
+}
+
+int BiasRun(void *cdata, int task_id, float l, float r) {
+  BiasAddStruct *bias_add = (BiasAddStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(bias_add);
+
+  float *input = (float *)(bias_add->base_.in_[FIRST_INPUT]->data_);
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  float *bias = (float *)(bias_add->base_.in_[SECOND_INPUT]->data_);
+  NNACL_CHECK_NULL_RETURN_ERR(bias);
+  float *output = (float *)(bias_add->base_.out_[FIRST_INPUT]->data_);
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+
+  int64_t block_start = bias_add->split_points_[task_id];
+  int64_t block_end = bias_add->total_num_;
+  if ((task_id + 1) < bias_add->split_points_size_) {
+    block_end = bias_add->split_points_[task_id + 1];
+  }
+  BiasAddOpt(input, bias, output, block_start, block_end, bias_add->inner_num_, bias_add->batch_priority_);
+  return NNACL_OK;
+}
+
+int BiasAddResize(struct KernelBase *self) {
+  BiasAddStruct *bias_add = (BiasAddStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(bias_add);
+
+  TensorC *in_tensor = self->in_[FIRST_INPUT];
+  TensorC *add_tensor = self->in_[SECOND_INPUT];
+  NNACL_CHECK_FALSE(in_tensor->shape_size_ == 0, NNACL_ERR);
+  NNACL_CHECK_FALSE(add_tensor->shape_size_ == 0, NNACL_ERR);
+  NNACL_CHECK_FALSE(in_tensor->shape_size_ < add_tensor->shape_size_, NNACL_ERR);
+
+  size_t dim_offset = in_tensor->shape_size_ - add_tensor->shape_size_;
+  bias_add->inner_num_ = 1;
+  for (size_t i = 0; i < add_tensor->shape_size_; ++i) {
+    NNACL_CHECK_FALSE(in_tensor->shape_[i + dim_offset] != add_tensor->shape_[i], NNACL_BIAS_ADD_SHAPE_NOT_MATCH);
+    NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_tensor->shape_[i], bias_add->inner_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW);
+    bias_add->inner_num_ *= add_tensor->shape_[i];
+  }
+
+  bias_add->outer_num_ = 1;
+  for (size_t i = 0; i < dim_offset; ++i) {
+    NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_tensor->shape_[i], bias_add->outer_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW);
+    bias_add->outer_num_ *= 
in_tensor->shape_[i]; + } + + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(bias_add->inner_num_, bias_add->outer_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->total_num_ = bias_add->inner_num_ * bias_add->outer_num_; + return ChooseBiasThreadCuttingStrategy(self); +} + +int BiasAddCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, BiasRun, self, self->thread_nr_); +} + +KernelBase *CreateBiasAdd(OpParameter *param, int data_type) { + BiasAddStruct *bias_add = (BiasAddStruct *)malloc(sizeof(BiasAddStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(bias_add); + bias_add->base_.Prepare = DefaultPrepare2In1Out; + bias_add->base_.Resize = BiasAddResize; + bias_add->base_.Release = DefaultRelease; + bias_add->base_.Compute = BiasAddCompute; + return (KernelBase *)bias_add; +} + +REG_KERNEL_CREATOR(PrimType_BiasAdd, kNumberTypeFloat32, CreateBiasAdd) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h new file mode 100644 index 0000000000000000000000000000000000000000..1a8577c83cf0011410744a2dff660a46357c1050 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BIASADD_H_ +#define NNACL_KERNEL_BIASADD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateBiasAdd(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_BIASADD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c new file mode 100644 index 0000000000000000000000000000000000000000..cfbfaf8519da1fe88e66aec84be69c8d4743b62b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c @@ -0,0 +1,209 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/cast.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/cast_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" + +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/cast_fp16.h" +#endif + +int CastToFp32(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; + float *output_data = (float *)output->data_; + switch (input_data_type) { + case kNumberTypeBool: + BoolToFloat32((const bool *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFloat32((const uint8_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFloat32((const int32_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeFloat16: +#ifdef ENABLE_FP16 + Fp16ToFloat32((const float16_t *)(input->data_) + offset, output_data + offset, data_num); +#else + Fp16ToFloat32((const uint16_t *)(input->data_) + offset, output_data + offset, data_num); +#endif + break; + case kNumberTypeInt64: + Int64ToFloat32((const int64_t *)(input->data_) + offset, output_data + offset, data_num); + break; + default: + return NNACL_ERR; + } + return NNACL_OK; +} + +int CastToFp16(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; +#ifdef ENABLE_FP16 + float16_t *output_data = (float16_t *)output->data_; + switch (input_data_type) { + case kNumberTypeFloat32: + Float32ToFp16((const float *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt64: + Int64ToFp16((const int64_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFp16((const int32_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeBool: + BoolToFp16((const bool *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFp16((const uint8_t *)(input->data_) + offset, output_data + offset, data_num); + break; + default: + return NNACL_ERR; + } +#else + if (input_data_type == kNumberTypeFloat32) { + Float32ToFp16((const float *)(input->data_) + offset, (uint16_t *)(output->data_) + offset, data_num); + } else { + return NNACL_ERR; + } +#endif + return NNACL_OK; +} + +int CastToOthers(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; + int output_data_type = output->data_type_; + void *output_data = output->data_; + if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { + Float32ToInt64((const float *)(input->data_) + offset, (int64_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { + Float32ToInt32((const float *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { + Int32ToInt64((const int32_t *)(input->data_) + offset, (int64_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeInt32) { + Int64ToInt32((const int64_t *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) { + Float32ToInt16((const float *)(input->data_) + offset, (int16_t *)(output_data) + offset, data_num); + 
} else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { + BoolToInt32((const bool *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeBool) { + Float32ToBool((const float *)(input->data_) + offset, (bool *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeUInt8) { + Float32ToUint8((const float *)(input->data_) + offset, (uint8_t *)(output_data) + offset, data_num); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +int CastLaunch(void *cdata, int task_id, float l, float r) { + CastStruct *cast = (CastStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cast); + + NNACL_CHECK_FALSE(cast->base_.in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(cast->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + TensorC *in = cast->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + NNACL_CHECK_NULL_RETURN_ERR(in->data_); + TensorC *out = cast->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out); + NNACL_CHECK_NULL_RETURN_ERR(out->data_); + + int stride = cast->stride_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, NNACL_ERR); + + int data_num = MSMIN(stride, cast->data_num_ - task_id * stride); + if (data_num <= 0) { + return NNACL_OK; + } + + int offset = task_id * stride; + int input_data_type = in->data_type_; + int output_data_type = out->data_type_; + if (input_data_type == output_data_type) { + size_t datalen = DataTypeCSize((TypeIdC)input_data_type); + memcpy((int8_t *)(out->data_) + offset * datalen, (int8_t *)(in->data_) + offset * datalen, data_num * datalen); + return NNACL_OK; + } + + if (output_data_type == kNumberTypeFloat32) { + return CastToFp32(in, out, offset, data_num); + } else if (output_data_type == kNumberTypeFloat16) { + return CastToFp16(in, out, offset, data_num); + } else { + return CastToOthers(in, out, offset, data_num); + } + return NNACL_OK; +} + +int cast_prepare(struct KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +// Kernel resize input shape +int cast_resize(struct KernelBase *self) { + CastStruct *cast = (CastStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(cast); + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + int data_num = NNACLGetElementNum(in_tensor); + if (data_num == 0) { + return NNACL_OK; + } + + cast->data_num_ = data_num; + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + // update thread num + cast->base_.thread_nr_ = cast->base_.UpdateThread( + TC_PTYPE(PrimType_Cast), 1, 1, NNACLGetElementNum(cast->base_.out_[FIRST_INPUT]), cast->base_.thread_nr_); + cast->stride_ = UP_DIV(data_num, cast->base_.thread_nr_); + return NNACL_OK; +} + +int cast_release(struct KernelBase *self) { return NNACL_OK; } + +// Cast Op Compute +int cast_compute(struct KernelBase *self) { + CastStruct *cast = (CastStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(cast); + if (cast->data_num_ == 0) { + return NNACL_OK; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, CastLaunch, self, self->thread_nr_); +} + +KernelBase *CreateCast(OpParameter *param, int data_type) { + CastStruct *cast = (CastStruct 
*)malloc(sizeof(CastStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(cast); + memset(cast, 0, sizeof(CastStruct)); + cast->base_.Prepare = cast_prepare; + cast->base_.Resize = cast_resize; + cast->base_.Release = cast_release; + cast->base_.Compute = cast_compute; + cast->stride_ = 0; + cast->data_num_ = 0; + return (KernelBase *)cast; +} + +// todo register kernel diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h new file mode 100644 index 0000000000000000000000000000000000000000..1312c4aaf31ca487e3227f510809ea8286df8fe1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CAST_H_ +#define NNACL_KERNEL_CAST_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct CastStruct { + KernelBase base_; + int stride_; + int data_num_; +} CastStruct; + +KernelBase *CreateCast(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CAST_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c new file mode 100644 index 0000000000000000000000000000000000000000..1d8ef8e3c9197303e283ef9fc2f32747f3e867ee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c @@ -0,0 +1,123 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/clip.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/clip_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int GetClipMinMaxValue(TensorC *tensor, float *data) { + NNACL_CHECK_NULL_RETURN_ERR(tensor); + switch (tensor->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + *data = *((float *)tensor->data_); + break; + case kNumberTypeInt: + case kNumberTypeInt32: + *data = *((int *)tensor->data_); + break; + default: + return NNACL_CLIP_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int ClipResize(struct KernelBase *self) { + ClipStruct *clip = (ClipStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(clip); + clip->base_.thread_nr_ = clip->base_.UpdateThread( + TC_PTYPE(PrimType_Clip), 1, 1, NNACLGetElementNum(clip->base_.out_[FIRST_INPUT]), clip->base_.thread_nr_); + + clip->length_ = NNACLGetElementNum(clip->base_.in_[FIRST_INPUT]); + clip->stride_ = UP_DIV(clip->length_, clip->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(clip->stride_, clip->base_.thread_nr_, NNACL_ERR); + return NNACL_OK; +} + +int ClipImpl(void *cdata, int task_id, float l, float r) { + ClipStruct *clip = (ClipStruct *)cdata; + void *in = clip->base_.in_[FIRST_INPUT]->data_; + void *out = clip->base_.out_[FIRST_INPUT]->data_; + + int stride = clip->stride_ * task_id; + int count = NNACL_MIN(clip->stride_, clip->length_ - stride); + if (count <= 0) { + return NNACL_OK; + } + + switch (clip->base_.in_[FIRST_INPUT]->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: { + return Fp32Clip((float *)in + stride, count, (float *)out + stride, clip->min_val_, clip->max_val_); + } break; + case kNumberTypeInt: + case kNumberTypeInt32: { + return Int32Clip((int *)in + stride, count, (int *)out + stride, (int)clip->min_val_, (int)clip->max_val_); + } break; + default: + return NNACL_CLIP_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int ClipCompute(struct KernelBase *self) { + ClipStruct *clip = (ClipStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(clip); + ClipParameter *param = (ClipParameter *)clip->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + clip->min_val_ = param->min_val_; + clip->max_val_ = param->max_val_; + + int ret = NNACL_OK; + if (clip->base_.in_size_ > ONE_TENSOR) { + TensorC *min_tensor = clip->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(min_tensor); + NNACL_CHECK_NULL_RETURN_ERR(min_tensor->data_); + ret = GetClipMinMaxValue(min_tensor, &(clip->min_val_)); + } + if (clip->base_.in_size_ > TWO_TENSOR) { + TensorC *max_tensor = clip->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(max_tensor); + NNACL_CHECK_NULL_RETURN_ERR(max_tensor->data_); + ret = GetClipMinMaxValue(max_tensor, &(clip->max_val_)); + } + if (ret != NNACL_OK) { + return ret; + } + if (clip->min_val_ >= clip->max_val_) { + return NNACL_CLIP_MINMAX_VALUE_INVALID; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ClipImpl, clip, clip->base_.thread_nr_); +} + +KernelBase *CreateClip(OpParameter *param, int data_type) { + ClipStruct *clip = (ClipStruct *)malloc(sizeof(ClipStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(clip); + clip->base_.Prepare = DefaultPrepare1In1Out; + clip->base_.Resize = ClipResize; + clip->base_.Release = DefaultRelease; + clip->base_.Compute = ClipCompute; + return (KernelBase *)clip; +} + +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeFloat, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeFloat32, CreateClip) 
+REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeInt, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeInt32, CreateClip) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h new file mode 100644 index 0000000000000000000000000000000000000000..23f91ee2eefad52f62862e74c6a77a06fbf0d97a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CLIP_H_ +#define NNACL_KERNEL_CLIP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ClipStruct { + KernelBase base_; + float min_val_; + float max_val_; + int length_; + int stride_; +} ClipStruct; + +KernelBase *CreateClip(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CLIP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c new file mode 100644 index 0000000000000000000000000000000000000000..4f382e868b66ce14bbf2a001462331ce6d8564c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c @@ -0,0 +1,287 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/concat.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" + +#define kConcatMinCostPerThread 16384 + +int DoConcat(ConcatStruct *concat, int task_id) { + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id > concat->block_size_, NNACL_ERR); + + int all_bytes = NNACLGetSize(concat->base_.out_[FIRST_INPUT]); + int64_t start = concat->block_splits_[task_id]; + int64_t end = task_id < (concat->block_size_ - 1) ? 
concat->block_splits_[task_id + 1] : all_bytes; + int64_t start_row = start / concat->inner_sizes_[concat->base_.in_size_]; + int64_t end_row = end / concat->inner_sizes_[concat->base_.in_size_]; + + size_t src_buf_size = concat->base_.in_size_ * sizeof(uint8_t *); + NNACL_CHECK_MALLOC_SIZE(src_buf_size); + uint8_t **src = (uint8_t **)concat->base_.env_->Alloc(concat->base_.env_->allocator_, src_buf_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(src); + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + if (concat->is_with_data_[i]) { + src[i] = concat->inputs_ptr_[i] + start_row * concat->inner_sizes_[i]; + } + } + uint8_t *out = concat->output_ + start; + + int input_index = concat->block_boundary_infos_[task_id].begin_input_; + int end_index = concat->block_boundary_infos_[task_id].end_input_; + if (start_row == end_row) { + if (input_index == end_index) { + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, + concat->block_boundary_infos_[task_id].end_point_ - concat->block_boundary_infos_[task_id].begin_point_); + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; + } + int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_; + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size); + out += size; + ++input_index; + for (; input_index < end_index; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + out += concat->inner_sizes_[input_index]; + } + memcpy(out, src[input_index], concat->block_boundary_infos_[task_id].end_point_); + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; + } + for (int i = 0; i < input_index; ++i) { + src[i] += concat->inner_sizes_[i]; + } + int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_; + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size); + src[input_index] += concat->inner_sizes_[input_index]; + out += size; + ++input_index; + for (; input_index < concat->base_.in_size_; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + src[input_index] += concat->inner_sizes_[input_index]; + out += concat->inner_sizes_[input_index]; + } + ++start_row; + for (; start_row < end_row; ++start_row) { + for (input_index = 0; input_index < concat->base_.in_size_; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + src[input_index] += concat->inner_sizes_[input_index]; + out += concat->inner_sizes_[input_index]; + } + } + for (input_index = 0; input_index < end_index; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + out += concat->inner_sizes_[input_index]; + } + memcpy(out, src[end_index], concat->block_boundary_infos_[task_id].end_point_); + + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; +} + +int ConcatRun(void *cdata, int task_id, float l, float r) { + ConcatStruct *concat = (ConcatStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(concat); + return DoConcat(concat, task_id); +} + +int InitConcatDynamicStatus(ConcatStruct *concat) { + ConcatParameter *param = (ConcatParameter *)concat->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + size_t i = 0; + int64_t output_inner_size = 0; + for (; i < concat->base_.in_size_; i++) { + TensorC *t = concat->base_.in_[i]; + NNACL_CHECK_FALSE(param->axis_ >= t->shape_size_, 
NNACL_CONCAT_AXIS_INVALID);
+    int64_t outer_size = 1;
+    for (int j = 0; j < param->axis_; ++j) {
+      outer_size *= t->shape_[j];
+    }
+    int inner_size = DataTypeCSize(concat->data_type_);
+    NNACL_CHECK_TRUE_RET(inner_size > 0, NNACL_UNSUPPORTED_DATA_TYPE);
+
+    for (int j = param->axis_; j < t->shape_size_; ++j) {
+      NNACL_CHECK_INT_MUL_NOT_OVERFLOW(inner_size, t->shape_[j], NNACL_CONCAT_SHAPE_INVALID);
+      inner_size *= t->shape_[j];
+    }
+    if (i == 0) {
+      concat->outer_size_ = outer_size;
+    } else {
+      NNACL_CHECK_TRUE_RET(concat->outer_size_ == outer_size, NNACL_CONCAT_SHAPE_INVALID);
+    }
+    if (inner_size == 0) {
+      concat->is_with_data_[i] = false;
+      concat->inner_sizes_[i] = inner_size;
+      continue;
+    }
+    concat->is_with_data_[i] = true;
+    concat->inner_sizes_[i] = inner_size;
+    output_inner_size += inner_size;
+  }
+  concat->inner_sizes_[i] = output_inner_size;
+  return NNACL_OK;
+}
+
+/* Map a byte offset inside one output row to the input it falls in and the offset within that input's row. */
+void ComputeConcatUnitBoundary(ConcatStruct *concat, int64_t *pre_sum, int64_t offset, int *input, int64_t *point) {
+  size_t index = 0;
+  for (; index < concat->base_.in_size_; ++index) {
+    if (offset < pre_sum[index]) {
+      break;
+    }
+  }
+  *input = index;
+  *point = concat->inner_sizes_[index] - (pre_sum[index] - offset);
+}
+
+int ChooseConcatThreadCuttingStrategy(ConcatStruct *concat) {
+  NNACL_CHECK_TRUE_RET(concat->base_.thread_nr_ > 0, NNACL_ERR);
+
+  int all_bytes = NNACLGetSize(concat->base_.out_[FIRST_INPUT]);
+  int64_t thread_count = MSMAX(1, MSMIN(all_bytes / kConcatMinCostPerThread, concat->base_.thread_nr_));
+
+  NNACL_CHECK_ZERO_RETURN_ERR(thread_count);
+  int64_t block_size = all_bytes / thread_count;
+  int64_t remain_byte = all_bytes - block_size * thread_count;
+  int64_t *pre_sum =
+    (int64_t *)concat->base_.env_->Alloc(concat->base_.env_->allocator_, concat->base_.in_size_ * sizeof(int64_t));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pre_sum);
+  int64_t init_sum = 0;
+  for (size_t i = 0; i < concat->base_.in_size_; ++i) {
+    init_sum += concat->inner_sizes_[i];
+    pre_sum[i] = init_sum;
+  }
+
+  concat->block_size_ = 0;
+
+  int64_t block_split = 0;
+  while (block_split < all_bytes) {
+    concat->block_splits_[concat->block_size_] = block_split;
+    block_split += block_size;
+    if (remain_byte > 0) {
+      ++block_split;
+      --remain_byte;
+    }
+    int64_t start = concat->block_splits_[concat->block_size_];
+    int64_t end = block_split > all_bytes ? all_bytes : block_split;
+    int64_t start_offset = start - DOWN_ROUND(start, concat->inner_sizes_[concat->base_.in_size_]);
+    int64_t end_offset = end - DOWN_ROUND(end, concat->inner_sizes_[concat->base_.in_size_]);
+    ConcatBlockBoundaryInfo block_boundary_info;
+    ComputeConcatUnitBoundary(concat, pre_sum, start_offset, &block_boundary_info.begin_input_,
+                              &block_boundary_info.begin_point_);
+    ComputeConcatUnitBoundary(concat, pre_sum, end_offset, &block_boundary_info.end_input_,
+                              &block_boundary_info.end_point_);
+    concat->block_boundary_infos_[concat->block_size_] = block_boundary_info;
+    concat->block_size_++;
+  }
+
+  concat->base_.thread_nr_ = concat->block_size_;
+  concat->base_.env_->Free(concat->base_.env_->allocator_, pre_sum);
+  return NNACL_OK;
+}
+
+int ConcatResize(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+  ConcatParameter *param = (ConcatParameter *)concat->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+
+  param->axis_ = param->axis_ >= 0 ? 
param->axis_ : self->in_[FIRST_INPUT]->shape_size_ + param->axis_;
+  NNACL_CHECK_FALSE(param->axis_ < 0, NNACL_CONCAT_AXIS_INVALID);
+  NNACL_CHECK_FALSE(param->axis_ >= self->in_[FIRST_INPUT]->shape_size_, NNACL_CONCAT_AXIS_INVALID);
+
+  int ret = InitConcatDynamicStatus(concat);
+  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
+
+  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
+    return NNACL_OK;
+  }
+
+  return ChooseConcatThreadCuttingStrategy(concat);
+}
+
+int ConcatPrepare(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+
+  concat->inputs_ptr_ = self->env_->Alloc(self->env_->allocator_, self->in_size_ * sizeof(uint8_t *));
+  NNACL_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_);
+  concat->is_with_data_ = self->env_->Alloc(self->env_->allocator_, self->in_size_ * sizeof(bool));
+  NNACL_CHECK_NULL_RETURN_ERR(concat->is_with_data_);
+  concat->inner_sizes_ =
+    self->env_->Alloc(self->env_->allocator_, (self->in_size_ + self->out_size_) * sizeof(int64_t));
+  NNACL_CHECK_NULL_RETURN_ERR(concat->inner_sizes_);
+
+  return NNACL_OK;
+}
+
+int ConcatRelease(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+  if (concat->inputs_ptr_ != NULL) {
+    self->env_->Free(self->env_->allocator_, concat->inputs_ptr_);
+  }
+  if (concat->is_with_data_ != NULL) {
+    self->env_->Free(self->env_->allocator_, concat->is_with_data_);
+  }
+  if (concat->inner_sizes_ != NULL) {
+    self->env_->Free(self->env_->allocator_, concat->inner_sizes_);
+  }
+  return NNACL_OK;
+}
+
+int ConcatCompute(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
+    return NNACL_OK;
+  }
+
+  for (size_t i = 0; i < self->in_size_; ++i) {
+    if (!concat->is_with_data_[i]) {
+      continue;
+    }
+    NNACL_CHECK_NULL_RETURN_ERR(self->in_[i]->data_);
+    concat->inputs_ptr_[i] = self->in_[i]->data_;
+  }
+
+  concat->output_ = self->out_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(concat->output_);
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatRun, self, self->thread_nr_);
+}
+
+KernelBase *CreateConcat(OpParameter *param, int data_type) {
+  ConcatStruct *concat = (ConcatStruct *)malloc(sizeof(ConcatStruct));
+  NNACL_CHECK_NULL_RETURN_NULL(concat);
+  memset(concat, 0, sizeof(ConcatStruct));
+  concat->data_type_ = (TypeIdC)data_type;
+  concat->inner_sizes_ = NULL;
+  concat->inputs_ptr_ = NULL;
+  concat->is_with_data_ = NULL;
+  concat->base_.Prepare = ConcatPrepare;
+  concat->base_.Resize = ConcatResize;
+  concat->base_.Release = ConcatRelease;
+  concat->base_.Compute = ConcatCompute;
+  return (KernelBase *)concat;
+}
+
+REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeBool, CreateConcat)
+REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeInt32, CreateConcat)
+REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat32, CreateConcat)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h
new file mode 100644
index 0000000000000000000000000000000000000000..cdc201f1d129a0e8ca66fc6e53d3d111010ab9b3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONCAT_H_
+#define NNACL_KERNEL_CONCAT_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct ConcatBlockBoundaryInfo {
+  int begin_input_;      // index of the first input this block reads from.
+  int end_input_;        // index of the last input this block reads from.
+  int64_t begin_point_;  // byte offset into begin_input_'s row where the block starts.
+  int64_t end_point_;    // number of bytes the block consumes from end_input_'s row.
+} ConcatBlockBoundaryInfo;
+
+typedef struct ConcatStruct {
+  KernelBase base_;
+  int64_t outer_size_;
+  uint8_t *output_;
+  TypeIdC data_type_;
+
+  bool *is_with_data_;   /* size = in_tensor_size */
+  uint8_t **inputs_ptr_; /* size = in_tensor_size */
+  int64_t *inner_sizes_; /* per-row byte size (axis and inner dims) of each input; the last entry is the output's. */
+
+  ConcatBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM]; /* dynamic block size */
+  int64_t block_splits_[MAX_THREAD_NUM];                         /* dynamic block size */
+  size_t block_size_;                                            /* dynamic block size = actual thread number */
+} ConcatStruct;
+
+KernelBase *CreateConcat(OpParameter *param, int data_type);
+int DoConcat(ConcatStruct *concat, int task_id);
+int ConcatPrepare(KernelBase *self);
+int ConcatRelease(KernelBase *self);
+int ConcatResize(KernelBase *self);
+
+#endif  // NNACL_KERNEL_CONCAT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c
new file mode 100644
index 0000000000000000000000000000000000000000..f52b7d285969a2251ab059fa719a92db7d5751eb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c
@@ -0,0 +1,365 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/convolution_1x1.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/base/conv1x1_base.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+
+int Conv1x1Run(void *cdata, int task_id, float l, float r) {
+  Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_1x1);
+  MatMulParameter *matmul = &conv_1x1->matmul_param_;
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->thread_stride_, NNACL_ERR);
+  int total_thread_stride_ = task_id * conv_1x1->thread_stride_;
+  int res_stride = matmul->col_ - total_thread_stride_;
+  int cur_oc = MSMIN(conv_1x1->thread_stride_, res_stride);
+  if (cur_oc <= 0) {
+    return NNACL_OK;
+  }
+
+  TensorC *out_tensor = conv_1x1->conv_.base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+  float *bias = conv_1x1->conv_.bias_data_ == NULL
+                  ? NULL
+                  : (float *)conv_1x1->conv_.bias_data_ + conv_1x1->thread_stride_ * task_id;
+  float *weight = (float *)conv_1x1->conv_.packed_weight_ + total_thread_stride_ * matmul->deep_;
+
+  if (out_tensor->format_ == Format_NC4HW4) {
+    MatMulOpt(conv_1x1->pack_input_, weight, conv_1x1->output_ptr_ + total_thread_stride_ * matmul->row_, bias,
+              matmul->act_type_, matmul->deep_, matmul->row_, cur_oc, matmul->row_, OutType_NC4HW4);
+  } else {
+    MatMulOpt(conv_1x1->pack_input_, weight, conv_1x1->output_ptr_ + total_thread_stride_, bias, matmul->act_type_,
+              matmul->deep_, matmul->row_, cur_oc, matmul->col_, OutType_Nhwc);
+  }
+  return NNACL_OK;
+}
+
+void Conv1x1PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) {
+#ifdef ENABLE_AVX
+  RowMajor2Col6Major(src_ptr, dst_ptr, row, col);
+#elif defined(ENABLE_SSE)
+  RowMajor2Col4Major(src_ptr, dst_ptr, row, col);
+#else
+  RowMajor2Col12Major(src_ptr, dst_ptr, row, col);
+#endif
+}
+
+int Conv1x1RunHw(void *cdata, int task_id, float l, float r) {
+  Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_1x1);
+  MatMulParameter *matmul = &conv_1x1->matmul_param_;
+  TensorC *output_tensor = conv_1x1->conv_.base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->thread_stride_, NNACL_ERR);
+  int total_thread_stride_ = task_id * conv_1x1->thread_stride_;
+  int res_stride = matmul->row_ - total_thread_stride_;
+  int cur_hw_ = MSMIN(conv_1x1->thread_stride_, res_stride);
+  if (cur_hw_ <= 0) {
+    return NNACL_OK;
+  }
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, matmul->deep_, NNACL_ERR);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->row_tile_, NNACL_ERR);
+  int total_row_tile_ = task_id * conv_1x1->row_tile_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_row_tile_, matmul->deep_, NNACL_ERR);
+  float *thread_input_ptr = conv_1x1->input_ptr_ + total_thread_stride_ * matmul->deep_;
+  float *thread_pack_input = conv_1x1->pack_input_ + total_row_tile_ * matmul->deep_;
+  float *thread_output_ptr = NULL;
+  if (output_tensor->format_ != Format_NC4HW4) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, matmul->col_, NNACL_ERR);
+    thread_output_ptr = conv_1x1->output_ptr_ + total_thread_stride_ * matmul->col_;
+  } else {
+    int col_min = MSMIN(matmul->col_, C4NUM);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, col_min, NNACL_ERR);
+    thread_output_ptr = conv_1x1->output_ptr_ + total_thread_stride_ * col_min;
+  }
+  float *cur_input = thread_input_ptr;
+  float *cur_output = thread_output_ptr;
+  float *bias = (float *)conv_1x1->conv_.bias_data_;
+  for (int i = 0; i < 
cur_hw_; i += conv_1x1->row_tile_) {
+    int cur_rows = (cur_hw_ - i >= conv_1x1->row_tile_) ? conv_1x1->row_tile_ : (cur_hw_ - i);
+    Conv1x1PackMatmulInput(cur_input, thread_pack_input, cur_rows, matmul->deep_);
+    if (output_tensor->format_ == Format_NC4HW4) {
+      MatMulOpt(thread_pack_input, (float *)conv_1x1->conv_.packed_weight_, cur_output, bias, matmul->act_type_,
+                matmul->deep_, cur_rows, matmul->col_, matmul->row_, OutType_NC4HW4);
+      cur_output += conv_1x1->row_tile_ * MSMIN(matmul->col_, C4NUM);
+    } else {
+      MatMulOpt(thread_pack_input, (float *)conv_1x1->conv_.packed_weight_, cur_output, bias, matmul->act_type_,
+                matmul->deep_, cur_rows, matmul->col_, matmul->col_, OutType_Nhwc);
+      cur_output += conv_1x1->row_tile_ * matmul->col_;
+    }
+    cur_input += conv_1x1->row_tile_ * matmul->deep_;
+  }
+
+  return NNACL_OK;
+}
+
+void Conv1x1PackWeight(ConvolutionBaseStruct *conv) {
+  TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_VOID(filter_tensor);
+  ConvComputeParam *compute = &conv->compute_;
+  NNACL_CHECK_NULL_RETURN_VOID(compute);
+
+  if (compute->in_c_ <= 0 || compute->out_c_ <= 0) {
+    return;
+  }
+
+  void *origin_weight = conv->base_.train_session_ ? filter_tensor->data_ : conv->origin_weight_;
+  NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
+
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_);
+#elif defined(ENABLE_ARM32)
+  RowMajor2Col4Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_);
+#else
+  RowMajor2Col8Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_);
+#endif
+}
+
+int Conv1x1MallocWeightBiasData(ConvolutionBaseStruct *conv) {
+  Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_1x1);
+
+  int size = conv->compute_.in_c_ * UP_ROUND(conv->compute_.out_c_, conv_1x1->col_tile_) * sizeof(float);
+  if (!conv->base_.train_session_) {
+    conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_);
+  }
+
+  if (conv->base_.in_size_ == THREE_TENSOR) {
+    size = UP_ROUND(conv->compute_.out_c_, conv_1x1->col_tile_) * sizeof(float);
+    conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_);
+    memset(conv->bias_data_, 0, size);
+  }
+  return NNACL_OK;
+}
+
+void Conv1x1FreeTmpBuffer(Convolution1x1Struct *conv_1x1) {
+  if (conv_1x1->pre_trans_input_ && conv_1x1->input_ptr_ != NULL) {
+    conv_1x1->conv_.base_.env_->Free(conv_1x1->conv_.base_.env_->allocator_, conv_1x1->input_ptr_);
+    conv_1x1->input_ptr_ = NULL;
+  }
+  return;
+}
+
+int InitConv1x1MatmulParam(Convolution1x1Struct *conv_1x1) {
+  ConvParameter *conv_param = (ConvParameter *)conv_1x1->conv_.base_.param_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->output_h_, conv_param->output_w_, NNACL_ERR);
+  conv_1x1->matmul_param_.row_ = conv_param->output_h_ * conv_param->output_w_;
+  conv_1x1->matmul_param_.col_ = conv_param->output_channel_;
+  conv_1x1->matmul_param_.deep_ = conv_param->input_channel_;
+  conv_1x1->matmul_param_.row_align_ = UP_ROUND(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_);
+  conv_1x1->matmul_param_.col_align_ = UP_ROUND(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_);
+  conv_1x1->matmul_param_.act_type_ = conv_param->act_type_;
+  return NNACL_OK;
+}
+
+int InitConv1x1Param(Convolution1x1Struct *conv_1x1) {
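+  /* Two cutting schemes: rows of the implicit GEMM (output H * W) are split
+   * across threads when row_ exceeds both row_tile_ * thread_nr_ and col_;
+   * otherwise output channels are split in col_tile_-aligned chunks. */
+ 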
NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->row_tile_, conv_1x1->conv_.base_.thread_nr_, NNACL_ERR); + if ((conv_1x1->matmul_param_.row_ > (conv_1x1->row_tile_ * conv_1x1->conv_.base_.thread_nr_)) && + (conv_1x1->matmul_param_.row_ > conv_1x1->matmul_param_.col_)) { + conv_1x1->multi_thread_by_hw_ = true; + conv_1x1->conv_.base_.thread_nr_ = + MSMIN(conv_1x1->conv_.base_.thread_nr_, UP_DIV(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_)); + if (conv_1x1->conv_.base_.thread_nr_ <= 0) { + return NNACL_ERR; + } + conv_1x1->thread_stride_ = + UP_DIV(UP_DIV(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_), conv_1x1->conv_.base_.thread_nr_) * + conv_1x1->row_tile_; + } else { + conv_1x1->multi_thread_by_hw_ = false; + conv_1x1->conv_.base_.thread_nr_ = + MSMIN(conv_1x1->conv_.base_.thread_nr_, UP_DIV(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_)); + if (conv_1x1->conv_.base_.thread_nr_ <= 0) { + return NNACL_ERR; + } + conv_1x1->thread_stride_ = + UP_DIV(UP_DIV(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_), conv_1x1->conv_.base_.thread_nr_) * + conv_1x1->col_tile_; + } + + ConvParameter *conv_param = (ConvParameter *)conv_1x1->conv_.base_.param_; + conv_1x1->pre_trans_input_ = + (conv_param->pad_u_ != 0 || conv_param->pad_l_ != 0 || conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1); + if (conv_1x1->pre_trans_input_) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + conv_1x1->input_ptr_ = (float *)(conv_1x1->conv_.base_.env_->Alloc( + conv_1x1->conv_.base_.env_->allocator_, + conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.deep_ * sizeof(float))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_1x1->input_ptr_); + memset(conv_1x1->input_ptr_, 0, conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.deep_ * sizeof(float)); + } + + return NNACL_OK; +} + +int Convolution1x1Resize(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + + Conv1x1FreeTmpBuffer(conv_1x1); + int error_code = ConvBasePrepare(&conv_1x1->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = InitConv1x1MatmulParam(conv_1x1); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = InitConv1x1Param(conv_1x1); + if (error_code != NNACL_OK) { + return error_code; + } + + return NNACL_OK; +} + +int Convolution1x1Prepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + +#ifdef ENABLE_AVX + conv_1x1->row_tile_ = C6NUM; + conv_1x1->col_tile_ = C16NUM; +#elif defined(ENABLE_SSE) + conv_1x1->row_tile_ = C4NUM; + conv_1x1->col_tile_ = C8NUM; +#elif defined(ENABLE_ARM32) + conv_1x1->row_tile_ = C12NUM; + conv_1x1->col_tile_ = C4NUM; +#else + conv_1x1->row_tile_ = C12NUM; + conv_1x1->col_tile_ = C8NUM; +#endif + + if (self->train_session_) { + int output_tile_size = UP_ROUND(conv_1x1->conv_.compute_.out_c_, conv_1x1->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->conv_.compute_.in_c_, output_tile_size, NNACL_ERR); + size_t size = conv_1x1->conv_.compute_.in_c_ * output_tile_size * sizeof(float); + conv_1x1->conv_.base_.work_size_ = size; + } + + int error_code = ConvBaseInitConvWeightBias(&conv_1x1->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + return NNACL_OK; +} + +int 
Convolution1x1Release(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + Conv1x1FreeTmpBuffer(conv_1x1); + ConvBaseRelease(&conv_1x1->conv_); + return NNACL_OK; +} + +int Convolution1x1Compute(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + float *src_in = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + float *src_out = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + int pack_input_size = 0; + if (conv_1x1->multi_thread_by_hw_) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->conv_.base_.thread_nr_, conv_1x1->row_tile_, NNACL_ERR); + int total_row_tile_ = conv_1x1->conv_.base_.thread_nr_ * conv_1x1->row_tile_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_row_tile_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + pack_input_size = total_row_tile_ * conv_1x1->matmul_param_.deep_; + } else { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_align_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + pack_input_size = conv_1x1->matmul_param_.row_align_ * conv_1x1->matmul_param_.deep_; + } + conv_1x1->pack_input_ = + (float *)conv_1x1->conv_.base_.env_->Alloc(conv_1x1->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_1x1->pack_input_); + + int ret = ConvBaseRepackWeight(&conv_1x1->conv_); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_, conv_1x1->matmul_param_.col_, NNACL_ERR); + int matmul_size = conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.col_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_batch_ - 1, matmul_size, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_h_, conv_param->input_w_, NNACL_ERR); + int conv_input_hw = conv_param->input_h_ * conv_param->input_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_hw, conv_param->input_channel_, NNACL_ERR); + int conv_input_bhw = conv_input_hw * conv_param->input_channel_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_batch_ - 1, conv_input_bhw, NNACL_ERR); + for (int batch_index = 0; batch_index < conv_param->input_batch_; batch_index++) { + conv_1x1->output_ptr_ = src_out + batch_index * matmul_size; + float *tmp_in = src_in + batch_index * conv_input_bhw; + if (conv_1x1->pre_trans_input_) { + Conv1x1InputPack(tmp_in, conv_1x1->input_ptr_, conv_param, sizeof(float)); + } else { + conv_1x1->input_ptr_ = tmp_in; + } + if (conv_1x1->multi_thread_by_hw_) { + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, Conv1x1RunHw, self, self->thread_nr_); + } else { + Conv1x1PackMatmulInput(conv_1x1->input_ptr_, conv_1x1->pack_input_, conv_1x1->matmul_param_.row_, + conv_1x1->matmul_param_.deep_); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, Conv1x1Run, self, self->thread_nr_); + } + if (ret != NNACL_OK) { + break; + } + } + + if (conv_1x1->pack_input_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_1x1->pack_input_); + conv_1x1->pack_input_ = NULL; + } + return ret; +} + +ConvolutionBaseStruct *CreateConvolution1x1(ConvParameter *conv_param) { + Convolution1x1Struct 
*conv1x1 = (Convolution1x1Struct *)malloc(sizeof(Convolution1x1Struct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv1x1);
+  memset(conv1x1, 0, sizeof(Convolution1x1Struct));
+
+  conv1x1->conv_.is_sharing_pack_ = false;
+  conv1x1->conv_.malloc_weight_bias_ = Conv1x1MallocWeightBiasData;
+  conv1x1->conv_.pack_weight_ = Conv1x1PackWeight;
+
+  conv1x1->conv_.base_.Resize = Convolution1x1Resize;
+  conv1x1->conv_.base_.Prepare = Convolution1x1Prepare;
+  conv1x1->conv_.base_.Release = Convolution1x1Release;
+  conv1x1->conv_.base_.Compute = Convolution1x1Compute;
+
+  return (ConvolutionBaseStruct *)conv1x1;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd26ed48728df0402d64e3f1401fe031b6ab8b51
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_1X1_H_
+#define NNACL_KERNEL_CONVOLUTION_1X1_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
+
+typedef struct Convolution1x1Struct {
+  ConvolutionBaseStruct conv_;
+  MatMulParameter matmul_param_;
+  int row_tile_;
+  int col_tile_;
+  bool pre_trans_input_;
+  float *input_ptr_;
+  float *output_ptr_;
+  float *pack_input_;
+  bool multi_thread_by_hw_;
+  int thread_stride_;
+} Convolution1x1Struct;
+
+ConvolutionBaseStruct *CreateConvolution1x1(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_1X1_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c
new file mode 100644
index 0000000000000000000000000000000000000000..8b49cbb34738d4d6c72a2230ccf249035ad6cdac
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c
@@ -0,0 +1,209 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param) { + compute->stride_h_ = conv_param->stride_h_; + compute->stride_w_ = conv_param->stride_w_; + compute->dilation_h_ = conv_param->dilation_h_; + compute->dilation_w_ = conv_param->dilation_w_; + compute->pad_u_ = conv_param->pad_u_; + compute->pad_d_ = conv_param->pad_d_; + compute->pad_l_ = conv_param->pad_l_; + compute->pad_r_ = conv_param->pad_r_; + + compute->in_c_ = conv_param->input_channel_; + compute->out_c_ = conv_param->output_channel_; + + compute->kernel_h_ = conv_param->kernel_h_; + compute->kernel_w_ = conv_param->kernel_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_h_, compute->kernel_w_, NNACL_ERR); + compute->kernel_hw_ = compute->kernel_h_ * compute->kernel_w_; + + return NNACL_OK; +} + +int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + TensorC *input = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = conv->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + conv_param->input_batch_ = NNACLGetBatch(input); + conv_param->input_h_ = NNACLGetHeight(input); + conv_param->input_w_ = NNACLGetWidth(input); + conv_param->input_channel_ = NNACLGetChannel(input); + conv_param->output_batch_ = NNACLGetBatch(output); + conv_param->output_h_ = NNACLGetHeight(output); + conv_param->output_w_ = NNACLGetWidth(output); + conv_param->output_channel_ = NNACLGetChannel(output); + + ConvComputeParam *compute = &conv->compute_; + compute->in_n_ = NNACLGetBatch(input); + compute->in_h_ = NNACLGetHeight(input); + compute->in_w_ = NNACLGetWidth(input); + compute->in_c_ = NNACLGetChannel(input); + NNACL_CHECK_FALSE(compute->in_c_ != conv_param->input_channel_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_h_, compute->in_w_, NNACL_ERR); + compute->in_hw_ = compute->in_h_ * compute->in_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_, compute->in_n_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_ * compute->in_n_, compute->in_c_, NNACL_ERR); + + compute->out_n_ = NNACLGetBatch(output); + compute->out_h_ = NNACLGetHeight(output); + compute->out_w_ = NNACLGetWidth(output); + compute->out_c_ = NNACLGetChannel(output); + NNACL_CHECK_FALSE(compute->out_c_ != conv_param->output_channel_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_h_, compute->out_w_, NNACL_ERR); + compute->out_hw_ = compute->out_h_ * compute->out_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_, compute->out_n_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_ * compute->out_n_, compute->out_c_, NNACL_ERR); + + return ConvBaseUpdateParamInfo(compute, conv_param); +} + +void ConvBaseRelease(ConvolutionBaseStruct *conv) { + if (!conv->base_.train_session_) { + if (!conv->is_sharing_pack_) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv->packed_weight_); + } else { + conv->free_sharing_weight_(conv->shaing_manager_, conv->packed_weight_); + } + conv->packed_weight_ = NULL; + } + + if (conv->bias_data_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv->bias_data_); + conv->bias_data_ = NULL; + } +} + +int ConvBasePrepare(ConvolutionBaseStruct *conv) { + NNACL_CHECK_FALSE(conv->base_.in_size_ < TWO_TENSOR, 
NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(conv->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + conv->out_format_ = conv->base_.out_[OUTPUT_INDEX]->format_; + return ConvBaseUpdateComputeInfo(conv); +} + +void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_VOID(conv); + + if (conv->base_.in_[SECOND_INPUT]->data_ != NULL) { + conv->origin_weight_ = conv->base_.in_[SECOND_INPUT]->data_; + } + + if (conv->base_.in_size_ == THREE_TENSOR && conv->base_.in_[THIRD_INPUT]->data_ != NULL) { + conv->origin_bias_ = conv->base_.in_[THIRD_INPUT]->data_; + } +} + +int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv) { + if (conv->base_.train_session_) { + ConvBaseUpdateOriginWeightAndBias(conv); + } + + /* check weight shape done */ + if (!CheckInferShapeDone(&conv->base_.in_[SECOND_INPUT], ONE_TENSOR, NULL, 0)) { + return NNACL_OK; + } + + int ret = conv->malloc_weight_bias_(conv); + if (ret != NNACL_OK) { + return ret; + } + + if ((conv->base_.in_size_ == THREE_TENSOR) && (conv->origin_bias_ != NULL)) { + memcpy(conv->bias_data_, conv->origin_bias_, NNACLGetSize(conv->base_.in_[THIRD_INPUT])); + } + + if (!conv->base_.train_session_) { + if (conv->weight_is_packed_) { + return NNACL_OK; + } + if (conv->origin_weight_ != NULL) { + conv->pack_weight_(conv); + } else { + conv->is_repack_ = true; + } + } + return NNACL_OK; +} + +int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv) { + // ===============check in channel================= // + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + int resize_in_channel = NNACLGetChannel(input_tensor); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + int filter_in_channel = NNACLGetChannel(filter_tensor); + if (filter_in_channel != resize_in_channel) { + return NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH; + } + return NNACL_OK; +} + +void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + bool const_fit = weight_tensor->category_ != ConstTensor && weight_tensor->category_ != ConstScalar; + bool group_fit = ((ConvParameter *)conv->base_.param_)->group_ > 1; + bool sharing_fit = conv->get_sharing_weight_ == NULL; + + void *data = NULL; + if (sharing_fit || const_fit || group_fit) { + if (data_size <= 0) { + return NULL; + } + data = conv->base_.env_->Alloc(conv->base_.env_->allocator_, data_size); + conv->weight_is_packed_ = false; + conv->is_sharing_pack_ = false; + } else { + data = conv->get_sharing_weight_(conv->shaing_manager_, weight_tensor->data_, data_size, &conv->weight_is_packed_); + } + return data; +} + +int ConvBaseRepackWeight(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_ERR(conv); + + conv->origin_weight_ = conv->origin_weight_ != NULL ? 
conv->origin_weight_ : conv->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv->origin_weight_); + + if (conv->packed_weight_ == NULL) { + int ret = ConvBaseInitConvWeightBias(conv); + if (ret != NNACL_OK) { + return ret; + } + } + + if (conv->is_repack_ || conv->base_.train_session_) { + if (conv->base_.train_session_) { + conv->packed_weight_ = (float *)conv->base_.workspace_; + memset(conv->packed_weight_, 0, conv->base_.work_size_); + } else { + conv->is_repack_ = false; + } + conv->pack_weight_(conv); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h new file mode 100644 index 0000000000000000000000000000000000000000..c932b60231fbfca5a032b3f699dc2d5374065b35 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_BASE_H_ +#define NNACL_KERNEL_CONVOLUTION_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +#define ConvMinBlock 1 + +typedef struct ConvolutionBaseStruct { + KernelBase base_; + ConvComputeParam compute_; + bool weight_is_packed_; + bool is_repack_; + bool infershape_done_; + bool use_batch_cut_flag_; + FormatC out_format_; + + void *packed_weight_; + void *bias_data_; + void *origin_weight_; // do not Free + void *origin_bias_; // do not Free + + void (*init_global_variable_)(struct ConvolutionBaseStruct *conv_im2col); + int (*malloc_weight_bias_)(struct ConvolutionBaseStruct *conv_base); + void (*pack_weight_)(struct ConvolutionBaseStruct *conv_base); + int (*run_impl_)(struct ConvolutionBaseStruct *conv, int task_id); + + bool is_sharing_pack_; + void *shaing_manager_; + void (*free_sharing_weight_)(void *manager, void *tensor_data); + void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed); +} ConvolutionBaseStruct; + +int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param); +int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv); +void ConvBaseRelease(ConvolutionBaseStruct *conv); +int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv); +int ConvBasePrepare(ConvolutionBaseStruct *conv); +int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv); +int ConvBaseRepackWeight(ConvolutionBaseStruct *conv); +void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv); +void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size); + +#endif // NNACL_KERNEL_CONVOLUTION_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c new file mode 100644 index
0000000000000000000000000000000000000000..7ee4aa2c43beed1e807866844242288a4cd8d312 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c @@ -0,0 +1,365 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/convolution_delegate.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/kernel/group_convolution.h" +#include "nnacl_c/kernel/convolution_depthwise.h" +#include "nnacl_c/kernel/convolution_1x1.h" +#include "nnacl_c/kernel/convolution_im2col.h" +#include "nnacl_c/kernel/convolution_winograd.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include "nnacl_c/kernel/convolution_depthwise_sw.h" +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_sw_1x1.h" +#include "nnacl_c/kernel/convolution_sw_avx.h" +#include "nnacl_c/kernel/convolution_depthwise_sw_avx.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_depthwise_indirect.h" +#include "nnacl_c/kernel/convolution_sw_arm64.h" +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" +#endif +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +#include "nnacl_c/kernel/convolution_depthwise_3x3.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#endif + +#define MaxDwConvSWSize 32 + +float *ConvolutionDelegateCopyData(const TensorC *tensor) { + NNACL_CHECK_NULL_RETURN_NULL(tensor); + NNACL_CHECK_NULL_RETURN_NULL(tensor->data_); + + float *data = (float *)malloc(NNACLGetSize(tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(data); + + (void)memcpy(data, tensor->data_, NNACLGetSize(tensor)); + return data; +} + +int ConvolutionDelegateGetWeightData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->conv_.base_.in_[SECOND_INPUT]->data_ == NULL) { + return NNACL_OK; + } + if (convolution_delegate->conv_.infershape_done_) { + convolution_delegate->origin_weight_ = convolution_delegate->conv_.base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_weight_); + convolution_delegate->need_free_weight_ = false; + return NNACL_OK; + } + convolution_delegate->origin_weight_ = + ConvolutionDelegateCopyData(convolution_delegate->conv_.base_.in_[SECOND_INPUT]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_weight_); + convolution_delegate->need_free_weight_ = true; + return NNACL_OK; +} + +int ConvolutionDelegateGetBiasData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->conv_.base_.in_size_ != THREE_TENSOR) { + convolution_delegate->origin_bias_ = NULL; + convolution_delegate->need_free_bias_ = false; + return NNACL_OK; + } + + if (convolution_delegate->conv_.infershape_done_) { + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->conv_.base_.in_[THIRD_INPUT]); + convolution_delegate->origin_bias_ = convolution_delegate->conv_.base_.in_[THIRD_INPUT]->data_; +
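// infershape has already run on this branch, so the delegate keeps a live pointer into the bias tensor instead of copying it; need_free_bias_ stays false below +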
NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_bias_); + convolution_delegate->need_free_bias_ = false; + return NNACL_OK; + } + + convolution_delegate->origin_bias_ = ConvolutionDelegateCopyData(convolution_delegate->conv_.base_.in_[THIRD_INPUT]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_bias_); + convolution_delegate->need_free_bias_ = true; + return NNACL_OK; +} + +int ConvolutionDelegateGetWeightAndBias(ConvolutionDelegateStruct *convolution_delegate) { + int ret = ConvolutionDelegateGetWeightData(convolution_delegate); + if (ret != NNACL_OK) { + return ret; + } + + return ConvolutionDelegateGetBiasData(convolution_delegate); +} + +ConvolutionBaseStruct *ConvolutionDelegateConvNC4KernelSelect(ConvolutionDelegateStruct *convolution_delegate) { + /* runtime nc4hw4 pass + * arm64: conv1x1 conv_Im2col support nc4 + * Avx: conv_Im2col support nc4 + * */ + ConvParameter *conv_param = (ConvParameter *)convolution_delegate->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + +#ifdef ENABLE_ARM64 + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + ConvolutionBaseStruct *conv1x1 = CreateConvolution1x1(conv_param); + return conv1x1; + } +#endif + +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) + ConvolutionBaseStruct *conv_im2col = CreateConvolutionIm2Col(&convolution_delegate->conv_.base_, conv_param); + return conv_im2col; +#endif + + return NULL; +} + +ConvolutionBaseStruct *ConvolutionDelegateConvNHWCKernelSelect(ConvolutionDelegateStruct *convolution_delegate) { + ConvParameter *conv_param = (ConvParameter *)convolution_delegate->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + + ConvolutionBaseStruct *conv = NULL; + + int out_unit; + if (CheckIfUseWinograd(&out_unit, conv_param)) { + conv = CreateConvolutionWinograd(conv_param, out_unit); + } + +#ifdef ENABLE_AVX + if (conv == NULL && CheckAvxUseSW1x1Conv(conv_param)) { + conv = CreateConvolutionSW1x1(conv_param, convolution_delegate->input_const_, convolution_delegate->weight_const_); + } + + if (conv == NULL && CheckAvxUseSWConv(conv_param, convolution_delegate->conv_.base_.thread_nr_)) { + conv = CreateConvolutionSWAVX(conv_param); + } +#endif + +#ifdef ENABLE_ARM64 + if (conv == NULL && CheckArm64UseSWConv(conv_param)) { + conv = CreateConvolutionSWARM64(conv_param); + } +#endif + + if (conv == NULL) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + conv = CreateConvolution1x1(conv_param); + } else { + conv = CreateConvolutionIm2Col(&convolution_delegate->conv_.base_, conv_param); + } + } + return conv; +} + +ConvolutionBaseStruct *ConvolutionDelegateConvolutionSelect(ConvolutionDelegateStruct *convolution_delegate) { + ConvolutionBaseStruct *conv; + if (convolution_delegate->conv_.base_.out_[OUTPUT_INDEX]->format_ == Format_NC4HW4) { + conv = ConvolutionDelegateConvNC4KernelSelect(convolution_delegate); + } else { + conv = ConvolutionDelegateConvNHWCKernelSelect(convolution_delegate); + } + if (conv == NULL) { + return NULL; + } + + conv->base_.InferShape = convolution_delegate->conv_.base_.InferShape; + conv->base_.UpdateThread = convolution_delegate->conv_.base_.UpdateThread; + conv->base_.env_ = convolution_delegate->conv_.base_.env_; + conv->base_.param_ = convolution_delegate->conv_.base_.param_; + conv->base_.thread_nr_ = convolution_delegate->conv_.base_.thread_nr_; + conv->base_.train_session_ = convolution_delegate->conv_.base_.train_session_; + conv->base_.in_ = convolution_delegate->conv_.base_.in_; + conv->base_.in_size_ 
= convolution_delegate->conv_.base_.in_size_; + conv->base_.out_ = convolution_delegate->conv_.base_.out_; + conv->base_.out_size_ = convolution_delegate->conv_.base_.out_size_; + + conv->infershape_done_ = convolution_delegate->conv_.infershape_done_; + conv->shaing_manager_ = convolution_delegate->conv_.shaing_manager_; + conv->get_sharing_weight_ = convolution_delegate->conv_.get_sharing_weight_; + conv->free_sharing_weight_ = convolution_delegate->conv_.free_sharing_weight_; + conv->is_sharing_pack_ = convolution_delegate->conv_.is_sharing_pack_; + + conv->origin_weight_ = convolution_delegate->origin_weight_; + conv->origin_bias_ = convolution_delegate->origin_bias_; + return conv; +} + +void ConvolutionDelegateFreeCopiedData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->origin_weight_ != NULL && convolution_delegate->need_free_weight_) { + free(convolution_delegate->origin_weight_); + } + convolution_delegate->origin_weight_ = NULL; + convolution_delegate->conv_.origin_weight_ = NULL; + convolution_delegate->need_free_weight_ = false; + + if (convolution_delegate->origin_bias_ != NULL && convolution_delegate->need_free_bias_) { + free(convolution_delegate->origin_bias_); + } + convolution_delegate->origin_bias_ = NULL; + convolution_delegate->conv_.origin_bias_ = NULL; + convolution_delegate->need_free_bias_ = false; +} + +int ConvolutionDelegateResize(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + + if (convolution_delegate->convolution_ == NULL) { + convolution_delegate->convolution_ = ConvolutionDelegateConvolutionSelect(convolution_delegate); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->convolution_); + (void)ConvBaseUpdateComputeInfo(convolution_delegate->convolution_); + int ret = convolution_delegate->convolution_->base_.Prepare(&convolution_delegate->convolution_->base_); + if (ret != NNACL_OK) { + return ret; + } + } + + (void)ConvBaseUpdateComputeInfo(convolution_delegate->convolution_); + int ret = convolution_delegate->convolution_->base_.Resize(&convolution_delegate->convolution_->base_); + if (ret != NNACL_OK) { + return ret; + } + + ConvolutionDelegateFreeCopiedData(convolution_delegate); + return NNACL_OK; +} + +int ConvolutionDelegatePrepare(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]); + + NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32 && + self->in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat16, + NNACL_CONVOLUTION_WEIGHT_DATATYPE_INVALID); + NNACL_CHECK_FALSE(self->in_size_ == THREE_TENSOR && self->in_[THIRD_INPUT] != NULL && + self->in_[THIRD_INPUT]->data_type_ != kNumberTypeFloat32, + NNACL_CONVOLUTION_BIAS_DATATYPE_INVALID); + + convolution_delegate->input_const_ = NNACLIsConst(self->in_[FIRST_INPUT]) && !self->train_session_; + convolution_delegate->weight_const_ = NNACLIsConst(self->in_[SECOND_INPUT]) && !self->train_session_; + + return ConvolutionDelegateGetWeightAndBias(convolution_delegate); +} + +int ConvolutionDelegateRelease(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct 
*)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + if (convolution_delegate->convolution_ != NULL) { + (void)convolution_delegate->convolution_->base_.Release(&convolution_delegate->convolution_->base_); + free(convolution_delegate->convolution_); + convolution_delegate->convolution_ = NULL; + } + if (convolution_delegate->need_free_weight_ && convolution_delegate->origin_weight_ != NULL) { + free(convolution_delegate->origin_weight_); + convolution_delegate->origin_weight_ = NULL; + } + if (convolution_delegate->need_free_bias_ && convolution_delegate->origin_bias_ != NULL) { + free(convolution_delegate->origin_bias_); + convolution_delegate->origin_bias_ = NULL; + } + return NNACL_OK; +} + +int ConvolutionDelegateCompute(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->convolution_); + + convolution_delegate->convolution_->base_.workspace_ = convolution_delegate->conv_.base_.workspace_; + return convolution_delegate->convolution_->base_.Compute(&convolution_delegate->convolution_->base_); +} + +KernelBase *CreateConvlutionDelegate(ConvParameter *conv_param) { + ConvolutionDelegateStruct *delegate = (ConvolutionDelegateStruct *)malloc(sizeof(ConvolutionDelegateStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(delegate); + memset(delegate, 0, sizeof(ConvolutionDelegateStruct)); + delegate->conv_.base_.Prepare = ConvolutionDelegatePrepare; + delegate->conv_.base_.Resize = ConvolutionDelegateResize; + delegate->conv_.base_.Release = ConvolutionDelegateRelease; + delegate->conv_.base_.Compute = ConvolutionDelegateCompute; + return (KernelBase *)delegate; +} + +KernelBase *CreateConvolutionDepthwise(ConvParameter *conv_param) { + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + KernelBase *kernel = NULL; + + if (conv_param->dynamic_shape_) { + kernel = CreateConvDw(conv_param); + if (kernel != NULL) { + return kernel; + } + } + +#ifdef ENABLE_AVX + kernel = CreateConvDwSWAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) + if (CheckConvDw1DWinograd(conv_param, conv_param->thread_num_)) { + kernel = CreateConvDw3x3(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif + +#ifdef ENABLE_ARM64 + if (CheckConvDwUseIndirectBuffer(conv_param)) { + kernel = CreateConvDwIndirect(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif + + if (conv_param->input_channel_ < MaxDwConvSWSize) { + kernel = CreateConvDwSW(conv_param); + if (kernel != NULL) { + return kernel; + } + } + + kernel = CreateConvDw(conv_param); + return kernel; +} + +KernelBase *CreateConv2DFusion(OpParameter *param, int data_type) { + ConvParameter *conv_param = (ConvParameter *)param; + conv_param->thread_num_ = param->thread_num_; + KernelBase *kernel; + if (conv_param->group_ == 1) { + kernel = CreateConvlutionDelegate(conv_param); + } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) { + kernel = CreateConvolutionDepthwise(conv_param); + } else { + kernel = CreateGroupConvolution(conv_param, data_type); + } + + if (kernel == NULL) { + return NULL; + } + + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)kernel; + (void)ConvBaseUpdateParamInfo(&conv->compute_, conv_param); + + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_Conv2DFusion, kNumberTypeFloat32, 
CreateConv2DFusion) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..0f8ebd83bb49cfee152dfbf1a82ccef3ca791709 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_CONVOLUTION_DELEGATE_H_ +#define NNACL_KERNEL_CONVOLUTION_DELEGATE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/conv_parameter.h" + +typedef struct ConvolutionDelegateStruct { + ConvolutionBaseStruct conv_; /* used for Conv2dFusion basic operator */ + ConvolutionBaseStruct *convolution_; /* real running conv */ + float *origin_weight_; + float *origin_bias_; + bool need_free_weight_; + bool need_free_bias_; + bool input_const_; + bool weight_const_; +} ConvolutionDelegateStruct; + +KernelBase *CreateConvlutionDelegate(ConvParameter *conv_param); +KernelBase *CreateConv2DFusion(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CONVOLUTION_DELEGATE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c new file mode 100644 index 0000000000000000000000000000000000000000..fbcb1761d3a5fa2317417b62438eca815ef3fe7f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c @@ -0,0 +1,229 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_depthwise.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#ifdef ENABLE_AVX512 +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#endif +#include "nnacl_c/fp32/conv_depthwise_avx_fp32.h" + +int ConvDwRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + +#ifdef ENABLE_AVX512 + if (X86_Avx512_Support()) { + return ConvDwAVX512(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); + } else { + return ConvDwAVX(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); + } +#endif + +#ifdef ENABLE_AVX + return ConvDwAVX(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); +#endif + + return ConvDw(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id); +} + +void ConvDwReleaseParam(ConvolutionDepthwiseStruct *conv_dw) { + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (conv_dw->dw_param_.num_pixels_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.num_pixels_); + conv_dw->dw_param_.num_pixels_ = NULL; + } + if (conv_dw->dw_param_.out_w_start_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.out_w_start_); + conv_dw->dw_param_.out_w_start_ = NULL; + } + if (conv_dw->dw_param_.out_w_end_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.out_w_end_); + conv_dw->dw_param_.out_w_end_ = NULL; + } +} + +void ConvDwPackWeight(ConvolutionBaseStruct *conv) { + void *origin_data = conv->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_VOID(origin_data); + PackWeightKHWToHWKFp32(origin_data, conv->packed_weight_, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +int ConvDwMallocWeightBiasData(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(weight_tensor); + + int pack_weight_size = conv->compute_.kernel_hw_ * conv->compute_.out_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(pack_weight_size, sizeof(float), NNACL_ERR); + + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + NNACL_CHECK_MALLOC_SIZE(conv->compute_.out_c_ * sizeof(float)); + if (conv->bias_data_ == NULL) { + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, conv->compute_.out_c_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, conv->compute_.out_c_ * sizeof(float)); + return NNACL_OK; +} + +int ConvDwInitConvDwCalcInfo(ConvolutionDepthwiseStruct *conv_dw) { + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + ConvComputeParam *compute = &conv_dw->conv_.compute_; + 
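// The kw loop below precomputes, for each kernel column kw, the output-column range [out_w_start[kw], out_w_end[kw]) that reads only in-bounds input, and records first_calc_kw_, the first kw whose window covers every output column. + // A worked example with hypothetical shapes (not taken from this diff): in_w_ = 8, out_w_ = 8, kernel_w_ = 3, stride_w_ = 1, dilation_w_ = 1, pad_l_ = 1 yields out_w_start = {1, 0, 0}, out_w_end = {8, 8, 7}, num_pixels = {7, 8, 7}, so first_calc_kw_ = 1. +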
NNACL_CHECK_NULL_RETURN_ERR(compute); + + ConvDwReleaseParam(conv_dw); + + conv_dw->dw_param_.num_pixels_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.num_pixels_); + + conv_dw->dw_param_.out_w_start_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_start_); + + conv_dw->dw_param_.out_w_end_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_end_); + + int *num_pixels = (int *)(conv_dw->dw_param_.num_pixels_); + int *out_w_start = (int *)(conv_dw->dw_param_.out_w_start_); + int *out_w_end = (int *)(conv_dw->dw_param_.out_w_end_); + conv_dw->dw_param_.first_calc_kw_ = -1; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->dilation_w_, (compute->kernel_w_ - 1), NNACL_ERR); + for (int kw = 0; kw < compute->kernel_w_; kw++) { + out_w_start[kw] = + NNACL_MAX(0, (compute->pad_l_ - compute->dilation_w_ * kw + compute->stride_w_ - 1) / compute->stride_w_); + + out_w_end[kw] = NNACL_MIN( + (compute->in_w_ + compute->pad_l_ - compute->dilation_w_ * kw + compute->stride_w_ - 1) / compute->stride_w_, + compute->out_w_); + + num_pixels[kw] = out_w_end[kw] - out_w_start[kw]; + if (conv_dw->dw_param_.first_calc_kw_ == -1 && out_w_start[kw] == 0 && num_pixels[kw] == compute->out_w_) { + conv_dw->dw_param_.first_calc_kw_ = kw; + } + } + return NNACL_OK; +} + +int ConvolutionDepthwisePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + TensorC *weight_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + NNACL_CHECK_TRUE_RET(weight_tensor->shape_size_ == DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int weight_size_hw = NNACLGetHeight(weight_tensor) * NNACLGetWidth(weight_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetBatch(weight_tensor), weight_size_hw, NNACL_ERR); + int pack_weight_size = NNACLGetBatch(weight_tensor) * weight_size_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(pack_weight_size, sizeof(float), NNACL_ERR); + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseCompute(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + conv_dw->input_ptr_ = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->input_ptr_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + conv_dw->output_ptr_ = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.num_pixels_); + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_start_); + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_end_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwRun, self, self->thread_nr_); +} + +int 
ConvolutionDepthwiseResize(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvBasePrepare(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + + ret = ConvDwInitConvDwCalcInfo(conv_dw); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int ConvolutionDepthwiseRelease(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvDwReleaseParam(conv_dw); + + ConvBaseRelease(&conv_dw->conv_); + return NNACL_OK; +} + +KernelBase *CreateConvDw(ConvParameter *conv) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)malloc(sizeof(ConvolutionDepthwiseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseStruct)); + + conv_dw->conv_.pack_weight_ = ConvDwPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwMallocWeightBiasData; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwisePrepare; + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseResize; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseRelease; + return (KernelBase *)conv_dw; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h new file mode 100644 index 0000000000000000000000000000000000000000..96af3331dde6861376dc993b82754124f496165a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_H_ +#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwiseStruct { + ConvolutionBaseStruct conv_; + ConvDwCalcParam dw_param_; + float *input_ptr_; + float *output_ptr_; +} ConvolutionDepthwiseStruct; + +int ConvolutionDepthwiseRelease(KernelBase *self); +KernelBase *CreateConvDw(ConvParameter *conv); + +#endif // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c new file mode 100644 index 0000000000000000000000000000000000000000..199ac49b728f22a771ac5a49782f0610d9288bab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c @@ -0,0 +1,154 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +#include "nnacl_c/kernel/convolution_depthwise_3x3.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +int ConvDw3x3Run(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int units = UP_DIV(conv_dw->conv_.compute_.out_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_dw->conv_.compute_.in_c_, C4NUM); + int c12c4_units = C12NUM * c4 * units; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4_units, task_id, NNACL_ERR); + float *buffer = conv_dw->buffer_ + c12c4_units * task_id; + NNACL_CHECK_ZERO_RETURN_ERR(conv_dw->conv_.base_.thread_nr_); + + int step_oh = UP_DIV(conv_dw->conv_.compute_.out_h_, conv_dw->conv_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(step_oh, task_id, NNACL_ERR); + int start_oh = step_oh * task_id; + int end_oh = MSMIN(start_oh + step_oh, conv_dw->conv_.compute_.out_h_); + + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + ConvDw3x3(conv_dw->output_ptr_, buffer, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, start_oh, end_oh); + return NNACL_OK; +} + +void ConvDw3x3PackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = (conv->base_.train_session_) ?
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackWeightConvDw3x3Fp32((float *)origin_weight, (float *)conv->packed_weight_, conv->compute_.out_c_); +} + +int ConvDw3x3MallocWeightBiasData(ConvolutionBaseStruct *conv) { + int c4 = UP_ROUND(conv->compute_.out_c_, C4NUM); + if (!conv->base_.train_session_) { + if (conv->packed_weight_ == NULL) { + int pack_weight_size = c4 * C12NUM; + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + } + + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(c4 * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, c4 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, c4 * sizeof(float)); + return NNACL_OK; +} + +int ConvolutionDepthwise3x3Resize(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + int ret = ConvBasePrepare(conv); + if (ret != NNACL_OK) { + return ret; + } + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv->compute_.out_h_); + return NNACL_OK; +} + +int ConvolutionDepthwise3x3Prepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int c4 = UP_ROUND(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c4, C12NUM, NNACL_ERR); + int pack_weight_size = c4 * C12NUM; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwise3x3Compute(KernelBase *self) { + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int units = UP_DIV(conv_dw->conv_.compute_.out_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_dw->conv_.compute_.in_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(C12NUM, c4, NNACL_ERR); + int c12c4 = C12NUM * c4; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4, units, NNACL_ERR); + int c12c4_units = c12c4 * units; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4_units, self->thread_nr_, NNACL_ERR); + int buffer_size = c12c4_units * self->thread_nr_; + + conv_dw->buffer_ = self->env_->Alloc(self->env_->allocator_, buffer_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->buffer_); + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + self->env_->Free(self->env_->allocator_, conv_dw->buffer_); + return ret; + } + + conv_dw->input_ptr_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->input_ptr_); + conv_dw->output_ptr_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDw3x3Run, self, self->thread_nr_); + self->env_->Free(self->env_->allocator_, conv_dw->buffer_); + return ret; +} + +int ConvolutionDepthwise3x3Release(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + ConvBaseRelease(conv); + return NNACL_OK; +} + +KernelBase 
*CreateConvDw3x3(ConvParameter *conv_param) { + ConvolutionDepthwise3x3Struct *conv_dw = + (ConvolutionDepthwise3x3Struct *)malloc(sizeof(ConvolutionDepthwise3x3Struct)); + NNACL_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwise3x3Struct)); + conv_dw->conv_.pack_weight_ = ConvDw3x3PackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDw3x3MallocWeightBiasData; + conv_dw->conv_.base_.Resize = ConvolutionDepthwise3x3Resize; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwise3x3Prepare; + conv_dw->conv_.base_.Compute = ConvolutionDepthwise3x3Compute; + conv_dw->conv_.base_.Release = ConvolutionDepthwise3x3Release; + + return (KernelBase *)conv_dw; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h new file mode 100644 index 0000000000000000000000000000000000000000..ecbd49ab98046e0cdaa2e3ac36cbfee2dd2530a1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_3X3_H_ +#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_3X3_H_ + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwise3x3Struct { + ConvolutionBaseStruct conv_; + float *buffer_; + float *input_ptr_; + float *output_ptr_; +} ConvolutionDepthwise3x3Struct; + +KernelBase *CreateConvDw3x3(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_3X3_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c new file mode 100644 index 0000000000000000000000000000000000000000..5b1c6d5e75afa673f60d6f2682dca2eeb6587089 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c @@ -0,0 +1,227 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_depthwise_indirect.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +int ConvDwIndirectMallocIndirectBuffer(ConvolutionDepthwiseIndirectStruct *conv_dw) { + ConvComputeParam *compute = &conv_dw->conv_.compute_; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(compute); + + conv_dw->step_w_ = compute->dilation_w_ == 1 ? compute->stride_w_ : compute->kernel_w_; + int step_w_2d = conv_dw->step_w_ * compute->kernel_h_; + conv_dw->step_h_ = (compute->kernel_h_ * compute->kernel_w_) + (compute->out_w_ - 1) * step_w_2d; + int step_h_2d = compute->out_h_ * conv_dw->step_h_; + int buffer_size = compute->out_n_ * step_h_2d; + + ExecEnv *env = conv_dw->conv_.base_.env_; + conv_dw->indirect_buffer_ = (float **)(env->Alloc(env->allocator_, buffer_size * sizeof(float *))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->indirect_buffer_); + return NNACL_OK; +} + +int ConvDwIndirectRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwIndirection(conv_dw->output_ptr_, conv_dw->indirect_buffer_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_dw->zero_ptr_, conv_param, task_id); + return NNACL_OK; +} + +int ConvDwIndirectMallocPackedInput(ConvolutionDepthwiseIndirectStruct *conv_dw) { + int IC_DIV = UP_DIV(conv_dw->conv_.compute_.in_c_, conv_dw->div_flag_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int conv_input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_bhw, conv_dw->div_flag_ * IC_DIV, NNACL_ERR); + int pack_input_size = conv_input_bhw * conv_dw->div_flag_ * IC_DIV; + conv_dw->packed_input_ = + conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + return NNACL_OK; +} + +void ConvDwIndirectPackWeight(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *origin_weight = (conv->base_.train_session_) ? 
weight_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + +#ifdef ENABLE_AVX + PackDepthwiseIndirectWeightC8Fp32(origin_weight, conv->packed_weight_, conv->compute_.kernel_h_, + conv->compute_.kernel_w_, conv->compute_.out_c_); +#else + PackDepthwiseIndirectWeightC4Fp32(origin_weight, conv->packed_weight_, conv->compute_.kernel_h_, + conv->compute_.kernel_w_, conv->compute_.out_c_); +#endif +} + +int ConvDwIndirectMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)conv; + ExecEnv *env = conv->base_.env_; + + int batch_flag = UP_DIV(conv->compute_.out_c_, conv_dw->div_flag_); + int pack_weight_size = conv_dw->div_flag_ * batch_flag * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + // malloc zero ptr + NNACL_CHECK_MALLOC_SIZE(batch_flag * conv_dw->div_flag_ * sizeof(float)); + conv_dw->zero_ptr_ = (float *)env->Alloc(env->allocator_, batch_flag * conv_dw->div_flag_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->zero_ptr_); + memset(conv_dw->zero_ptr_, 0, batch_flag * conv_dw->div_flag_ * sizeof(float)); + + // malloc bias ptr + if (conv->bias_data_ == NULL) { + conv->bias_data_ = env->Alloc(env->allocator_, batch_flag * conv_dw->div_flag_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, batch_flag * conv_dw->div_flag_ * sizeof(float)); + return NNACL_OK; +} + +int ConvolutionDepthwiseIndirectCompute(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + void *input_ptr = input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + + if (conv_dw->conv_.compute_.in_c_ % conv_dw->div_flag_ != 0) { + int ret = ConvDwIndirectMallocPackedInput(conv_dw); + if (ret != NNACL_OK) { + return ret; + } +#ifdef ENABLE_AVX + PackNHWCToNHWC8Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); +#else + PackNHWCToNHWC4Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); +#endif + } else { + conv_dw->packed_input_ = input_ptr; + } + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + conv_dw->output_ptr_ = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwInitIndirection(conv_dw->indirect_buffer_, conv_dw->packed_input_, conv_dw->zero_ptr_, conv_param, + conv_dw->step_h_, conv_dw->step_w_); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwIndirectRun, self, self->thread_nr_); + + if (conv_dw->conv_.compute_.in_c_ % conv_dw->div_flag_ != 0) { + self->env_->Free(self->env_->allocator_, conv_dw->packed_input_); + } + return ret; +} +int ConvolutionDepthwiseIndirectResize(KernelBase *self) { + 
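// a shape change invalidates the indirection pointers, so Resize frees any stale buffer here and rebuilds it below at the new output size via ConvDwIndirectMallocIndirectBuffer +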
ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + if (conv_dw->indirect_buffer_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->indirect_buffer_); + conv_dw->indirect_buffer_ = NULL; + } + + int ret = ConvBasePrepare(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvDwIndirectMallocIndirectBuffer(conv_dw); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + return NNACL_OK; +} + +int ConvolutionDepthwiseIndirectPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int batch_flag = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->div_flag_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->div_flag_ * batch_flag, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR); + int pack_weight_size = conv_dw->div_flag_ * batch_flag * conv_dw->conv_.compute_.kernel_hw_; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseIndirectRelease(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + if (conv_dw->zero_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->zero_ptr_); + conv_dw->zero_ptr_ = NULL; + } + if (conv_dw->indirect_buffer_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->indirect_buffer_); + conv_dw->indirect_buffer_ = NULL; + } + ConvBaseRelease(&conv_dw->conv_); + return NNACL_OK; +} + +KernelBase *CreateConvDwIndirect(ConvParameter *conv_param) { + ConvolutionDepthwiseIndirectStruct *conv_dw = + (ConvolutionDepthwiseIndirectStruct *)malloc(sizeof(ConvolutionDepthwiseIndirectStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseIndirectStruct)); + +#ifdef ENABLE_AVX + conv_dw->div_flag_ = C8NUM; +#else + conv_dw->div_flag_ = C4NUM; +#endif + conv_dw->conv_.pack_weight_ = ConvDwIndirectPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwIndirectMallocWeightBiasData; + + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseIndirectCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseIndirectResize; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseIndirectPrepare; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseIndirectRelease; + + return (KernelBase *)conv_dw; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h new file mode 100644 index 0000000000000000000000000000000000000000..0008c7725786fe9c35032f52977cbe9ae85db692 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_INDIRECT_H_ +#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_INDIRECT_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwiseIndirectStruct { + ConvolutionBaseStruct conv_; + int div_flag_; + int step_w_; + int step_h_; + float *zero_ptr_; + float *output_ptr_; + float *packed_input_; + float **indirect_buffer_; +} ConvolutionDepthwiseIndirectStruct; + +KernelBase *CreateConvDwIndirect(ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_INDIRECT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c new file mode 100644 index 0000000000000000000000000000000000000000..1caf86f9d3da2fdb84aaca96ef95bddb614fed58 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c @@ -0,0 +1,200 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_depthwise_sw.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +int ConvDwSWMallocWeightBiasData(ConvolutionBaseStruct *conv) { + int OC4 = UP_DIV(conv->compute_.out_c_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + int malloc_size = NNACL_MAX(conv->compute_.out_c_, C4NUM * OC4); + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(malloc_size * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, malloc_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, malloc_size * sizeof(float)); + conv->base_.thread_nr_ = NNACL_MIN(conv->base_.thread_nr_, OC4); + return NNACL_OK; +} + +int ConvDwSWInitPackedInputOutput(ConvolutionDepthwiseSWStruct *conv_dw) { + if (conv_dw->conv_.compute_.in_c_ % C4NUM == 0) { + conv_dw->need_align_ = false; + return NNACL_OK; + } + + conv_dw->need_align_ = true; + int IC4 = UP_DIV(conv_dw->conv_.compute_.in_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int conv_input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_bhw, C4NUM * IC4, NNACL_ERR); + int pack_input_size = conv_input_bhw * C4NUM * IC4; + NNACL_CHECK_MALLOC_SIZE(pack_input_size * sizeof(float)); + conv_dw->packed_input_ = + (float *)conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + + int OC4 = UP_DIV(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.out_n_, conv_dw->conv_.compute_.out_hw_, NNACL_ERR); + int output_bhw = conv_dw->conv_.compute_.out_n_ * conv_dw->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, C4NUM * OC4, NNACL_ERR); + int pack_output_size = output_bhw * C4NUM * OC4; + NNACL_CHECK_MALLOC_SIZE(pack_output_size * sizeof(float)); + conv_dw->packed_output_ = + (float *)conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_output_); + return NNACL_OK; +} + +int ConvDwSWRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwSWFp32(conv_dw->packed_output_, conv_dw->packed_input_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, &conv_dw->sliding_, task_id); + return NNACL_OK; +} + +void ConvDwSWFreePackedInputOutput(ConvolutionDepthwiseSWStruct *conv_dw) { + if (conv_dw->need_align_) { + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_input_); + conv_dw->packed_input_ = NULL; + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_output_); + conv_dw->packed_output_ = NULL; + } +} + +void 
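+// Depthwise weights arrive as NCHW with the batch dimension fixed at 1 (one KHxKW plane per output
+// channel); PackNCHWToNC4HW4Fp32 below regroups those planes into C4NUM channel blocks so the C4
+// sliding-window kernel can load four channels per step.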
ConvDwSWPackWeight(ConvolutionBaseStruct *conv) {
+  void *origin_weight = (conv->base_.train_session_) ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_;
+  NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
+  PackNCHWToNC4HW4Fp32(origin_weight, conv->packed_weight_, 1, conv->compute_.kernel_hw_, conv->compute_.out_c_);
+}
+
+int ConvolutionDepthwiseSWResize(KernelBase *self) {
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  int ret = ConvBasePrepare(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ConvParameter *conv_param = (ConvParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+  InitSlidingParamConvDw(&conv_dw->sliding_, conv_param, C4NUM);
+
+  self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_);
+  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
+  return NNACL_OK;
+}
+
+int ConvolutionDepthwiseSWCompute(KernelBase *self) {
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int ret = ConvDwSWInitPackedInputOutput(conv_dw);
+  if (ret != NNACL_OK) {
+    ConvDwSWFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  ret = ConvBaseRepackWeight(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    ConvDwSWFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  TensorC *input_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  float *input_ptr = (float *)input_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  if (conv_dw->need_align_) {
+    PackNHWCToNHWC4Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_,
+                        conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_);
+  } else {
+    conv_dw->packed_input_ = input_ptr;
+  }
+
+  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+  float *output_ptr = (float *)output_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+  if (!conv_dw->need_align_) {
+    conv_dw->packed_output_ = output_ptr;
+  }
+
+  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwSWRun, self, self->thread_nr_);
+
+  if (conv_dw->need_align_) {
+    PackNHWCXToNHWCFp32(conv_dw->packed_output_, output_ptr, conv_dw->conv_.compute_.out_n_,
+                        conv_dw->conv_.compute_.out_hw_, conv_dw->conv_.compute_.out_c_, C4NUM);
+  }
+
+  ConvDwSWFreePackedInputOutput(conv_dw);
+  return ret;
+}
+
+int ConvolutionDepthwiseSWPrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  ConvParameter *conv_param = (ConvParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_);
+
+  if (self->train_session_) {
+    int OC4 = UP_DIV(conv_dw->conv_.compute_.out_c_, C4NUM);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM * OC4, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR);
+    int pack_weight_size = C4NUM * OC4 * conv_dw->conv_.compute_.kernel_hw_;
+    self->work_size_ = pack_weight_size * sizeof(float);
+  }
+
+  return ConvBaseInitConvWeightBias(&conv_dw->conv_);
+}
+
+int ConvolutionDepthwiseSWRelease(KernelBase *self) {
+  ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv);
+  ConvBaseRelease(conv);
+  return NNACL_OK;
+}
+
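+// A minimal lifecycle sketch for this kernel, assuming the caller (normally the litert runtime) has
+// already populated base_.env_, base_.param_, base_.in_ and base_.out_. Kept under #if 0 because the
+// runtime owns the real call sequence; RunConvDwSWOnce is a hypothetical driver, not part of nnacl.
+#if 0
+static int RunConvDwSWOnce(ConvParameter *param) {
+  KernelBase *kernel = CreateConvDwSW(param);
+  if (kernel == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int ret = kernel->Prepare(kernel);  // cache origin weight/bias, pack once
+  if (ret == NNACL_OK) {
+    ret = kernel->Resize(kernel);     // sliding params + thread count from shapes
+  }
+  if (ret == NNACL_OK) {
+    ret = kernel->Compute(kernel);    // pack input, parallel launch, unpack output
+  }
+  kernel->Release(kernel);
+  free(kernel);
+  return ret;
+}
+#endif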
+KernelBase *CreateConvDwSW(ConvParameter *conv_param) {
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)malloc(sizeof(ConvolutionDepthwiseSWStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw);
+  memset(conv_dw, 0, sizeof(ConvolutionDepthwiseSWStruct));
+
+  conv_dw->conv_.malloc_weight_bias_ = ConvDwSWMallocWeightBiasData;
+  conv_dw->conv_.pack_weight_ = ConvDwSWPackWeight;
+  conv_dw->conv_.base_.Resize = ConvolutionDepthwiseSWResize;
+  conv_dw->conv_.base_.Compute = ConvolutionDepthwiseSWCompute;
+  conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseSWPrepare;
+  conv_dw->conv_.base_.Release = ConvolutionDepthwiseSWRelease;
+  return (KernelBase *)conv_dw;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h
new file mode 100644
index 0000000000000000000000000000000000000000..a7d6819ee56c42c4f10e4dff6f07a737cd94c0f8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_H_
+#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionDepthwiseSWStruct {
+  ConvolutionBaseStruct conv_;
+  SlidingWindowParam sliding_;
+  float *packed_input_;
+  float *packed_output_;
+  bool need_align_;
+} ConvolutionDepthwiseSWStruct;
+
+KernelBase *CreateConvDwSW(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c
new file mode 100644
index 0000000000000000000000000000000000000000..84c830bc65f798dfea838fa14c59856de5885875
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c
@@ -0,0 +1,216 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_depthwise_sw_avx.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/tensor_c.h" + +int ConvDwSWAVXInitPackedInputOutput(ConvolutionDepthwiseSWAVXStruct *conv_dw) { + conv_dw->input_need_align_ = (conv_dw->conv_.compute_.in_c_ % conv_dw->oc_tile_ != 0); + conv_dw->output_need_align_ = (conv_dw->conv_.compute_.out_c_ % conv_dw->oc_tile_ != 0); + + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + if (conv_dw->input_need_align_) { + int ic_algin = UP_DIV(conv_dw->conv_.compute_.in_c_, conv_dw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, conv_dw->oc_tile_ * ic_algin, NNACL_ERR); + int pack_input_size = input_bhw * conv_dw->oc_tile_ * ic_algin; + conv_dw->packed_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + } + + if (conv_dw->output_need_align_) { + int oc_algin = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.out_n_, conv_dw->conv_.compute_.out_hw_, NNACL_ERR); + int output_bhw = conv_dw->conv_.compute_.out_n_ * conv_dw->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_dw->oc_tile_ * oc_algin, NNACL_ERR); + int pack_output_size = output_bhw * conv_dw->oc_tile_ * oc_algin; + conv_dw->packed_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_output_); + } + + return NNACL_OK; +} + +void ConvDwSWAVXPackWeight(ConvolutionBaseStruct *conv) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_dw); + + int oc_algin = UP_DIV(conv->compute_.out_c_, conv_dw->oc_tile_); + void *origin_weight = conv->base_.train_session_ ? 
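+// In a train session the weight tensor can be rewritten between steps, so the live tensor data is
+// repacked on every ConvBaseRepackWeight; inference reuses the origin_weight_ cached at Prepare time.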
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_;
+  NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
+
+  PackNHWCToNXHWCXFp32(conv->compute_.kernel_h_, conv->compute_.kernel_w_, conv->compute_.out_c_, oc_algin, 1,
+                       (float *)conv->packed_weight_, (float *)origin_weight);
+}
+
+int ConvDwSWAVXMallocWeightBiasData(ConvolutionBaseStruct *conv) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int oc_algin = UP_DIV(conv->compute_.out_c_, conv_dw->oc_tile_);
+  int pack_weight_size = oc_algin * conv_dw->oc_tile_ * conv->compute_.kernel_hw_;
+
+  if (!conv->base_.train_session_) {
+    NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float));
+    conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_);
+  }
+
+  if (conv->base_.in_size_ == THREE_TENSOR) {
+    int bias_size = oc_algin * conv_dw->oc_tile_;
+    NNACL_CHECK_MALLOC_SIZE(bias_size * sizeof(float));
+    conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, bias_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_);
+    memset(conv->bias_data_, 0, bias_size * sizeof(float));
+  }
+  return NNACL_OK;
+}
+
+int ConvDwSWAvxRun(void *cdata, int task_id, float l, float r) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  DepthwiseSWAvxFp32(conv_dw->packed_output_, conv_dw->packed_input_, (float *)conv_dw->conv_.packed_weight_,
+                     (float *)conv_dw->conv_.bias_data_, conv_param, &conv_dw->sliding_param_, task_id);
+  return NNACL_OK;
+}
+
+void ConvDwSWAVXFreePackedInputOutput(ConvolutionDepthwiseSWAVXStruct *conv_dw) {
+  if (conv_dw->input_need_align_) {
+    conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_input_);
+    conv_dw->packed_input_ = NULL;
+    conv_dw->input_need_align_ = false;
+  }
+  if (conv_dw->output_need_align_) {
+    conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_output_);
+    conv_dw->packed_output_ = NULL;
+    conv_dw->output_need_align_ = false;
+  }
+}
+
+int ConvolutionDepthwiseSWAVXCompute(KernelBase *self) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int ret = ConvDwSWAVXInitPackedInputOutput(conv_dw);
+  if (ret != NNACL_OK) {
+    ConvDwSWAVXFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  ret = ConvBaseRepackWeight(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    ConvDwSWAVXFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  TensorC *input_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  float *input_ptr = (float *)input_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+
+  if (conv_dw->input_need_align_) {
+    PackNHWCToNHWCXFp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_,
+                        conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_, conv_dw->oc_tile_);
+  } else {
+    conv_dw->packed_input_ = input_ptr;
+  }
+
+  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+  float *output_ptr = (float *)output_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+
+  if (!conv_dw->output_need_align_) {
+    conv_dw->packed_output_ = output_ptr;
+  }
+
+  ret =
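+  // ParallelLaunch fans ConvDwSWAvxRun out over thread_nr_ tasks; each task computes its own slice of
+  // the output through sliding_param_. A serial equivalent, sketch only, would be:
+#if 0
+  for (int t = 0; t < self->thread_nr_; ++t) {
+    ret = ConvDwSWAvxRun(self, t, 0.0f, 0.0f);
+    if (ret != NNACL_OK) {
+      break;
+    }
+  }
+#endif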
self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwSWAvxRun, self, self->thread_nr_); + + if (conv_dw->output_need_align_) { + PackNHWCXToNHWCFp32(conv_dw->packed_output_, output_ptr, conv_dw->conv_.compute_.out_n_, + conv_dw->conv_.compute_.out_hw_, conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_); + } + + ConvDwSWAVXFreePackedInputOutput(conv_dw); + return ret; +} + +int ConvolutionDepthwiseSWAVXPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + conv_dw->oc_tile_ = C8NUM; + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int oc_algin = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_algin * conv_dw->oc_tile_, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR); + int pack_weight_size = oc_algin * conv_dw->oc_tile_ * conv_dw->conv_.compute_.kernel_hw_; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseSWAVXResize(KernelBase *self) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvBasePrepare(&conv_dw->conv_); + + InitSlidingParamConvDw(&conv_dw->sliding_param_, conv_param, conv_dw->oc_tile_); + return NNACL_OK; +} + +int ConvolutionDepthwiseSWAVXRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +KernelBase *CreateConvDwSWAVX(ConvParameter *conv_param) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = + (ConvolutionDepthwiseSWAVXStruct *)malloc(sizeof(ConvolutionDepthwiseSWAVXStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseSWAVXStruct)); + + conv_dw->conv_.pack_weight_ = ConvDwSWAVXPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwSWAVXMallocWeightBiasData; + + conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseSWAVXPrepare; + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseSWAVXCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseSWAVXResize; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseSWAVXRelease; + return (KernelBase *)conv_dw; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..8a76ccbdff41b8b440a4d02aa92f8851b09d3450 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_AVX_H_ +#define NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwiseSWAVXStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sliding_param_; + int oc_tile_; + float *packed_input_; + float *packed_output_; + bool input_need_align_; + bool output_need_align_; +} ConvolutionDepthwiseSWAVXStruct; + +KernelBase *CreateConvDwSWAVX(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c new file mode 100644 index 0000000000000000000000000000000000000000..482333cd617443e4002534fbfcaa4901bac4efaa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c @@ -0,0 +1,81 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/convolution_im2col.h" +#include "nnacl_c/kernel/convolution_im2col_base.h" +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/convolution_im2col_arm32.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_im2col_arm64.h" +#endif +#ifdef ENABLE_SSE +#include "nnacl_c/kernel/convolution_im2col_sse.h" +#endif +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_im2col_avx.h" +#endif +#ifdef ENABLE_AVX512 +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/kernel/convolution_im2col_avx512.h" +#endif + +ConvolutionBaseStruct *CreateConvolutionIm2Col(KernelBase *base, ConvParameter *conv_param) { + ConvolutionBaseStruct *kernel = NULL; + +#ifdef ENABLE_AVX512 + FormatC out_format = base->out_[OUTPUT_INDEX]->format_; + if (out_format != Format_NC4HW4) { + AVX512_HARDWARE_SELF_AWARENESS_BEGIN; + kernel = CreateConvIm2ColAVX512(conv_param); + if (kernel != NULL) { + return kernel; + } + AVX512_HARDWARE_SELF_AWARENESS_END; + } +#endif + +#ifdef ENABLE_AVX + kernel = CreateConvIm2ColAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_SSE + kernel = CreateConvIm2ColSSE(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM64 + kernel = CreateConvIm2ColARM64(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM32 + kernel = CreateConvIm2ColARM32(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + + kernel = CreateConvIm2ColBase(conv_param); + return kernel; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h new file mode 100644 index 0000000000000000000000000000000000000000..ab115d531eb88c6a59b5038206116160f046b8ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionIm2Col(KernelBase *base, ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c new file mode 100644 index 0000000000000000000000000000000000000000..44f69f651ee267f2874e7a80c5e82883915fd3aa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c @@ -0,0 +1,45 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/convolution_im2col_arm32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +void ConvIm2ColARM32InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C4NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col4Major; +} + +ConvolutionBaseStruct *CreateConvIm2ColARM32(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM32InitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h new file mode 100644 index 0000000000000000000000000000000000000000..c928273ead29050d2b777ffd220eae7d9f3bb259 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM32_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColARM32(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..96bd0c947f9dad8019390b164ba14c081d432a86 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c @@ -0,0 +1,72 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_im2col_arm64.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" + +void ConvIm2ColARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +} + +int ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv->out_format_ != Format_NC4HW4) { + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + } else { + ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM64InitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColARM64RunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..92288bc7ba6e41217f7abb0c4103f4c12c287523 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM64_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM64_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..2d1efc035b00d34a4c5db3f7afa6e2277eb15731 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c @@ -0,0 +1,151 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_im2col_avx.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" + +void ConvIm2ColAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C16NUM; + conv_im2col->row_tile_ = C6NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col16Major; +} + +int ConvIm2ColAVXInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + int kernel_chw = conv_im2col->conv_.compute_.kernel_hw_ * conv_im2col->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + int unit_size = total_kernel_chw * conv_im2col->row_tile_; + + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + if (conv_im2col->col_major_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + conv_im2col->col_major_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_); + + conv_im2col->output_need_align_ = + conv_im2col->conv_.compute_.out_c_ % conv_im2col->oc_tile_ != 0 && conv_im2col->conv_.out_format_ == Format_NC4HW4; + if (conv_im2col->output_need_align_) { + int oc_algin = UP_DIV(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + int output_bhw = conv_im2col->conv_.compute_.out_n_ * 
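+// Sizing example with illustrative numbers: out_c_ = 20 and oc_tile_ = C16NUM give
+// oc_algin = UP_DIV(20, 16) = 2, so the aligned temp output holds out_n_ * out_hw_ * 32 floats and the
+// 12 padding channels are dropped again when the buffer is unpacked after compute.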
conv_im2col->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_im2col->oc_tile_ * oc_algin, NNACL_ERR); + int pack_output_size = output_bhw * conv_im2col->oc_tile_ * oc_algin; + + if (conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + } + conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_); + } + return NNACL_OK; +} + +int ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *ori_input_data = conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + + if (conv->out_format_ != Format_NC4HW4) { + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + } else { + ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } + return NNACL_OK; +} + +int ConvolutionIm2colAvxCompute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_addr); + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + + if (conv_im2col->output_need_align_) { + PackNC8HW8AlignedToNC8HW8NotAlignedFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_, + conv_im2col->conv_.compute_.out_w_ * conv_im2col->conv_.compute_.out_h_, + conv_im2col->conv_.compute_.out_c_); + } else { + conv_im2col->tmp_output_ = NULL; + } + + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColAVXInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVXInitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColAVXRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = 
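+// Unlike Prepare/Resize/Release, the AVX creator installs its own Compute: the NC4HW4 output path
+// needs the post-launch unpack step handled in ConvolutionIm2colAvxCompute above.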
ConvolutionIm2colAvxCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..48e51e66024f93681f67253a5b3787c16130f40a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c new file mode 100644 index 0000000000000000000000000000000000000000..9b4a65ea748bc60077e0400cda8f012eceb9c850 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c @@ -0,0 +1,146 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifdef ENABLE_AVX512
+#include "nnacl_c/kernel/convolution_im2col_avx512.h"
+#include "nnacl_c/fp32/conv_im2col_avx512_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/tensor_c.h"
+
+void ConvIm2ColAVX512InitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  conv_im2col->oc_tile_ = C16NUM;
+  conv_im2col->row_tile_ =
+    MSMIN(UP_DIV(conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.base_.thread_nr_), C150NUM);
+  conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col64Major;
+}
+
+int ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct *conv_im2col) {
+  ExecEnv *env = conv_im2col->conv_.base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+  ConvComputeParam *compute = &conv_im2col->conv_.compute_;
+  NNACL_CHECK_NULL_RETURN_ERR(compute);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_hw_, compute->in_c_, NNACL_ERR);
+  int kernel_chw = compute->kernel_hw_ * compute->in_c_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR);
+  int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR);
+  size_t unit_size = total_kernel_chw * conv_im2col->row_tile_;
+
+  if (conv_im2col->packed_input_ != NULL) {
+    env->Free(env->allocator_, conv_im2col->packed_input_);
+    conv_im2col->packed_input_ = NULL;
+  }
+  conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_);
+
+  conv_im2col->output_need_align_ = compute->out_c_ % conv_im2col->oc_tile_ != 0;
+  if (conv_im2col->output_need_align_) {
+    if (conv_im2col->tmp_output_ != NULL) {
+      env->Free(env->allocator_, conv_im2col->tmp_output_);
+      conv_im2col->tmp_output_ = NULL;
+    }
+
+    // avx512 need to malloc dst aligned to C16NUM
+    int oc_algin = UP_ROUND(compute->out_c_, conv_im2col->oc_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR);
+    int output_bhw = compute->out_n_ * compute->out_hw_;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_algin, NNACL_ERR);
+    size_t pack_output_size = output_bhw * oc_algin;
+
+    conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_);
+  }
+
+  return NNACL_OK;
+}
+
+int ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
+  if (conv->out_format_ == Format_NC4HW4) {
+    return NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT;
+  }
+
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
+  ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+  float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
+
+  if (conv->use_batch_cut_flag_) {
+    ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
+                                   (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
+                                   conv_im2col->row_tile_);
+  } else {
+    ConvIm2ColAVX512Fp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
+                         (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
+                         conv_im2col->row_tile_);
+  }
+  return NNACL_OK;
+}
+
+int ConvolutionIm2colAvx512Compute(KernelBase *self) {
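+  // row_tile_ is dynamic here (see ConvIm2ColAVX512InitGlobalVariable). Illustrative numbers:
+  // out_hw_ = 224 * 224 = 50176 on 4 threads gives UP_DIV(50176, 4) = 12544, capped at C150NUM (150),
+  // so each GEMM call consumes at most 150 output pixels per task.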
ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_; + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + + if (conv_im2col->output_need_align_) { + PackNHWCXToNHWCFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_, + conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + } else { + conv_im2col->tmp_output_ = NULL; + } + + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColAVX512InitTmpBuffer; + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVX512InitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColAVX512RunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvx512Compute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h new file mode 100644 index 0000000000000000000000000000000000000000..8fde7b9f4bfaf0343c4fe43f52ccce408103c250 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX512_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX512_H_ + +#ifdef ENABLE_AVX512 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c new file mode 100644 index 0000000000000000000000000000000000000000..36a9846e642536d402c35bd00b079dfd83ede9db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c @@ -0,0 +1,246 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/convolution_im2col_base.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" + +int ConvIm2ColBaseImpl(void *cdata, int task_id, float l, float r) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cdata); + return conv->run_impl_(conv, task_id); +} + +int ConvIm2ColBaseRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + + float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + return NNACL_OK; +} + +int ConvIm2ColBaseMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + size_t oc_block_num = UP_ROUND(conv->compute_.out_c_, conv_im2col->oc_tile_); + size_t pack_weight_size = oc_block_num * conv->compute_.in_c_ * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(oc_block_num * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, 
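+// The bias buffer is rounded up to oc_block_num floats and zero-filled below, so the padded output
+// channels accumulate an exact 0.0f bias instead of reading uninitialized memory.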
oc_block_num * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, oc_block_num * sizeof(float)); + return NNACL_OK; +} + +int ConvIm2ColBaseUpdateThreadNumProcess(KernelBase *self, int32_t kernel_type, int64_t per_unit_load_num, + int64_t per_unit_store_num, int64_t unit_num) { +#ifdef DYNAMIC_THREAD_DISTRIBUTE + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + if (conv_im2col->conv_.compute_.in_n_ % self->thread_nr_ == 0) { + conv_im2col->conv_.use_batch_cut_flag_ = true; + return NNACL_OK; + } else { + conv_im2col->conv_.use_batch_cut_flag_ = false; + } + + int update_thread = UP_DIV(UP_DIV(conv_im2col->conv_.compute_.out_hw_, conv_im2col->row_tile_), ConvMinBlock); + self->thread_nr_ = NNACL_MIN(self->thread_nr_, update_thread); +#else + self->thread_nr_ = self->thread_nr_ > 0 ? self->thread_nr_ : 1; +#endif + return NNACL_OK; +} + +void ConvIm2ColBaseFreeTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + if (conv_im2col->col_major_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + if (conv_im2col->output_need_align_ && conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + conv_im2col->output_need_align_ = false; + } +} + +int ConvIm2ColBaseInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)conv_im2col; + TensorC *out_tensor = conv_im2col->conv_.base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv->compute_.kernel_hw_, conv->compute_.in_c_, NNACL_ERR); + int kernel_chw = conv->compute_.kernel_hw_ * conv->compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv->base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv->base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + int unit_size = total_kernel_chw * conv_im2col->row_tile_; + + if (conv_im2col->packed_input_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = + (float *)conv->base_.env_->Alloc(conv->base_.env_->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + if (conv_im2col->col_major_input_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + conv_im2col->col_major_input_ = + (float *)conv->base_.env_->Alloc(conv->base_.env_->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_); + + return NNACL_OK; +} + +void ConvIm2ColBasePackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = (conv->base_.train_session_) ? 
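+// As in the depthwise kernels, train sessions repack from the live weight tensor on every repack,
+// while inference reuses the origin_weight_ pointer cached at Prepare time.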
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_im2col->row_major_to_col_nmajor_); + conv_im2col->row_major_to_col_nmajor_((float *)origin_weight, (float *)conv->packed_weight_, conv->compute_.out_c_, + conv->compute_.in_c_ * conv->compute_.kernel_hw_); +} + +void ConvIm2ColBaseInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +} + +int ConvolutionIm2colBaseRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +int ConvolutionIm2colBaseCompute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_addr); + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +int ConvolutionIm2colBaseResize(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + + int ret = ConvBaseCheckResizeValid(conv); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(conv); + if (ret != NNACL_OK) { + return ret; + } + + return ConvIm2ColBaseUpdateThreadNumProcess(self, TC_PTYPE(PrimType_Conv2DFusion), 0, 0, 0); +} + +int ConvolutionIm2colBasePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + conv_im2col->conv_.init_global_variable_(&conv_im2col->conv_); + + if (self->train_session_) { + int oc_block_num = UP_ROUND(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + int kernel_chw = conv_im2col->conv_.compute_.in_c_ * conv_im2col->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_block_num, kernel_chw, NNACL_ERR); + int pack_weight_size = oc_block_num * kernel_chw; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_im2col->conv_); +} + +ConvolutionBaseStruct *CreateConvIm2ColBase(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.run_impl_ = 
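+// CreateConvIm2ColBase is the portable fallback used when no ISA-specific creator applies in
+// CreateConvolutionIm2Col: C8 output tiles, C12 row tiles, RowMajor2Col8Major weight packing.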
ConvIm2ColBaseRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColBaseInitGlobalVariable; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h new file mode 100644 index 0000000000000000000000000000000000000000..ce8ec9e112509ce8229355504e59aaf3647a052b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h @@ -0,0 +1,52 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_BASE_H_ +#define NNACL_KERNEL_CONVOLUTION_IM2COL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct ConvolutionIm2ColBaseStruct { + ConvolutionBaseStruct conv_; + int oc_tile_; + int row_tile_; + + float *tmp_output_; + float *packed_input_; + float *col_major_input_; + bool output_need_align_; + + void (*row_major_to_col_nmajor_)(const float *src_ptr, float *dst_ptr, int row, int col); + int (*init_tmp_buffer_)(struct ConvolutionIm2ColBaseStruct *conv_im2col); +} ConvolutionIm2ColBaseStruct; + +int ConvIm2ColBaseMallocWeightBiasData(ConvolutionBaseStruct *conv); +int ConvIm2ColBaseInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col); +int ConvIm2ColBaseImpl(void *cdata, int task_id, float l, float r); +void ConvIm2ColBaseFreeTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col); +void ConvIm2ColBasePackWeight(ConvolutionBaseStruct *conv); +int ConvIm2ColBaseRunImpl(ConvolutionBaseStruct *conv, int task_id); +int ConvolutionIm2colBaseCompute(KernelBase *self); +int ConvolutionIm2colBasePrepare(KernelBase *self); +int ConvolutionIm2colBaseResize(KernelBase *self); +int ConvolutionIm2colBaseRelease(KernelBase *self); +ConvolutionBaseStruct *CreateConvIm2ColBase(ConvParameter *conv_param); + +#endif  // NNACL_KERNEL_CONVOLUTION_IM2COL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c new file mode 100644 index 0000000000000000000000000000000000000000..c08d3b09a3b1611bef0c0d30d1734c8b731b8a66 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c @@ -0,0 +1,47 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_SSE +#include "nnacl_c/kernel/convolution_im2col_sse.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/tensor_c.h" + +void ConvIm2ColSSEInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C4NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +}
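+ +// The SSE variant only overrides the tile geometry: row_tile_ shrinks from the base C12NUM to C4NUM, +// matching the four floats held by a 128-bit XMM register, while the weight layout callback stays the +// generic RowMajor2Col8Major (8-channel output blocks). Everything else reuses the im2col base kernel.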
+ +ConvolutionBaseStruct *CreateConvIm2ColSSE(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColSSEInitGlobalVariable; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h new file mode 100644 index 0000000000000000000000000000000000000000..9762eee9eeca61a5720d386b47916211400361a4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_SSE_H_ +#define NNACL_KERNEL_CONVOLUTION_IM2COL_SSE_H_ + +#ifdef ENABLE_SSE +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColSSE(ConvParameter *conv_param); +#endif +#endif  // NNACL_KERNEL_CONVOLUTION_IM2COL_SSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c new file mode 100644 index 0000000000000000000000000000000000000000..c031575b82ddffdbca8b7eef3507ecaeab86d518 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c @@ -0,0 +1,227 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +#include "nnacl_c/kernel/convolution_slidewindow.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/tensor_c_utils.h" + +int ConvSWInitTmpBuffer(ConvolutionSWStruct *conv_sw) { + TensorC *input_tensor = conv_sw->conv_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_data = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + ConvComputeParam *compute = &conv_sw->conv_.compute_; + NNACL_CHECK_NULL_RETURN_ERR(compute); + + if (conv_sw->ic_res_ != 0 && compute->kernel_h_ == 1 && compute->kernel_w_ == 1) { + int ic_block_num = UP_DIV(compute->in_c_, conv_sw->in_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_n_, compute->in_hw_, NNACL_ERR); + int input_bhw = compute->in_n_ * conv_sw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, ic_block_num * conv_sw->in_tile_, NNACL_ERR); + + conv_sw->input_data_ = (float *)conv_sw->conv_.base_.env_->Alloc( + conv_sw->conv_.base_.env_->allocator_, input_bhw * ic_block_num * conv_sw->in_tile_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->input_data_); + + PackNHWCToNHWCXFp32(input_data, conv_sw->input_data_, compute->in_n_, compute->in_hw_, compute->in_c_, + conv_sw->oc_tile_); + } else { + conv_sw->input_data_ = input_data; + } + + float *out_data = (float *)conv_sw->conv_.base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + if (conv_sw->oc_res_ == 0) { // no need to malloc dst + conv_sw->output_data_ = out_data; + } else { // need to malloc dst to align block + int oc_block_num = UP_DIV(compute->out_c_, conv_sw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR); + int output_bhw = compute->out_n_ * compute->out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_block_num * conv_sw->oc_tile_, NNACL_ERR); + conv_sw->output_data_ = (float *)conv_sw->conv_.base_.env_->Alloc( +
conv_sw->conv_.base_.env_->allocator_, output_bhw * oc_block_num * conv_sw->oc_tile_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->output_data_); + } + + return NNACL_OK; +} + +void ConvSWFreeTmpBuffer(ConvolutionSWStruct *conv_sw) { + ConvParameter *conv_param = (ConvParameter *)conv_sw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + if (conv_sw->output_data_ != NULL && conv_sw->oc_res_ != 0) { + conv_sw->conv_.base_.env_->Free(conv_sw->conv_.base_.env_->allocator_, conv_sw->output_data_); + conv_sw->output_data_ = NULL; + } + if (conv_sw->input_data_ != NULL && conv_sw->ic_res_ != 0 && conv_param->kernel_w_ == 1 && + conv_param->kernel_h_ == 1) { + conv_sw->conv_.base_.env_->Free(conv_sw->conv_.base_.env_->allocator_, conv_sw->input_data_); + conv_sw->input_data_ = NULL; + } +} + +void ConvSWPackWeight(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(filter_tensor); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + void *origin_weight = (conv->base_.train_session_) ? filter_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNHWCToNXHWCXFp32(kernel_h, kernel_w, output_channel, oc_block_num, input_channel, (float *)conv->packed_weight_, + (float *)origin_weight); +} + +int ConvSWMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + conv_param->input_channel_ = input_channel; + conv_param->output_channel_ = output_channel; + int kernel_plane = kernel_h * kernel_w; + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + int pack_weight_size = oc_block_num * conv_sw->oc_tile_ * input_channel * kernel_plane; + if (!conv_sw->conv_.base_.train_session_) { + conv_sw->conv_.packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->conv_.packed_weight_); + } + + if (conv_sw->conv_.base_.in_size_ == THREE_TENSOR) { + int malloc_size = oc_block_num * conv_sw->oc_tile_ * sizeof(float); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, malloc_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + memset(conv->bias_data_, 0, oc_block_num * conv_sw->oc_tile_ * sizeof(float)); + } + return NNACL_OK; +} + +int ConvSWImpl(void *cdata, int task_id, float l, float r) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + return conv_sw->conv_.run_impl_(&conv_sw->conv_, task_id); +} + +int ConvolutionSWCompute(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + + int ret = 
ConvSWInitTmpBuffer(conv_sw); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + ret = ConvBaseRepackWeight(&conv_sw->conv_); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvSWImpl, self, self->thread_nr_); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + if (conv_sw->oc_res_ != 0) { + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *out_data = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + PackNHWCXToNHWCFp32(conv_sw->output_data_, out_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_, conv_sw->oc_tile_); + } + + ConvSWFreeTmpBuffer(conv_sw); + return NNACL_OK; +} + +int ConvolutionSWRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +int ConvolutionSWResize(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + int ret = ConvBaseCheckResizeValid(&conv_sw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&conv_sw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + InitSlidingParamConv(&conv_sw->sw_param_, conv_param, conv_sw->in_tile_, conv_sw->oc_tile_); + return NNACL_OK; +} + +int ConvolutionSWPrepare(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + + conv_sw->conv_.init_global_variable_(&conv_sw->conv_); + + if (self->train_session_) { + TensorC *filter_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + NNACL_CHECK_FALSE(filter_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_h, kernel_w, NNACL_ERR); + int kernel_hw = kernel_h * kernel_w; + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_channel, kernel_hw, NNACL_ERR); + int kernel_chw = input_channel * kernel_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_block_num * conv_sw->oc_tile_, kernel_chw, NNACL_ERR); + int pack_weight_size = oc_block_num * conv_sw->oc_tile_ * kernel_chw; + + conv_sw->conv_.base_.work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_sw->conv_); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h new file mode 100644 index 0000000000000000000000000000000000000000..a888b9f23522fb4f37ae41f33b51185fe113b636 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h @@ -0,0 +1,46 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_SLIDEWINDOW_H_ +#define NNACL_KERNEL_CONVOLUTION_SLIDEWINDOW_H_ + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/matmul_parameter.h" + +typedef struct ConvolutionSWStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sw_param_; + int oc_tile_; + int in_tile_; + int oc_res_; + int ic_res_; + float *output_data_; + float *input_data_; +} ConvolutionSWStruct; + +int ConvolutionSWPrepare(KernelBase *self); +int ConvolutionSWCompute(KernelBase *self); +int ConvolutionSWResize(KernelBase *self); +int ConvolutionSWRelease(KernelBase *self); +void ConvSWPackWeight(ConvolutionBaseStruct *conv); +int ConvSWMallocWeightBiasData(ConvolutionBaseStruct *conv); +#endif +#endif  // NNACL_KERNEL_CONVOLUTION_SLIDEWINDOW_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c new file mode 100644 index 0000000000000000000000000000000000000000..c0aa92ed33ffe945c2f16f633081191429c0e1bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c @@ -0,0 +1,152 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_sw_1x1.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_create.h"
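+ +// A 1x1 stride-1 convolution is a plain GEMM: every output pixel is the dot product of its IC input +// channels with an IC x OC weight matrix. The helpers below therefore map the conv shape onto the +// embedded matmul as row = N * H * W, deep = IC, col = OC before delegating Prepare/Resize to it.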
+int MatmulConv1x1Prepare(ConvolutionSW1x1Struct *sw_1x1) { + sw_1x1->matmul_->batch_ = 1; + sw_1x1->matmul_->a_batch_ = 1; + sw_1x1->matmul_->b_batch_ = 1; + + sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_; + sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_; + sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_; + + return sw_1x1->matmul_->base_.Prepare(&sw_1x1->matmul_->base_); +} + +int MatmulConv1x1Resize(ConvolutionSW1x1Struct *sw_1x1) { + sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_; + sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_; + sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_; + + MatmulBaseFreeBatchOffset(sw_1x1->matmul_); + int ret = MatmulBaseMallocBatchOffset(sw_1x1->matmul_); + if (ret != NNACL_OK) { + return ret; + } + + return sw_1x1->matmul_->base_.Resize(&sw_1x1->matmul_->base_); +} + +void UpdateTensorInfo(KernelBase *self, ConvolutionSW1x1Struct *sw_1x1) { + sw_1x1->matmul_->base_.in_ = self->in_; + sw_1x1->matmul_->base_.in_size_ = self->in_size_; + sw_1x1->matmul_->base_.out_ = self->out_; + sw_1x1->matmul_->base_.out_size_ = self->out_size_; + sw_1x1->matmul_->base_.workspace_ = self->workspace_; +} + +int ConvolutionSW1x1Compute(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_); + + UpdateTensorInfo(self, sw_1x1); + return sw_1x1->matmul_->base_.Compute(&sw_1x1->matmul_->base_); +} + +int ConvolutionSW1x1Resize(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_); + + UpdateTensorInfo(self, sw_1x1); + return MatmulConv1x1Resize(sw_1x1); +} + +int ConvolutionSW1x1Prepare(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_); + + sw_1x1->matmul_->matrix_b_.origin_ptr_ = sw_1x1->conv_.origin_weight_; + sw_1x1->matmul_->matrix_b_.origin_need_free_ = false; + sw_1x1->matmul_->matrix_c_.origin_ptr_ = sw_1x1->conv_.origin_bias_; + sw_1x1->matmul_->matrix_c_.origin_need_free_ = false; + + sw_1x1->matmul_->infer_shape_ = sw_1x1->conv_.infershape_done_; + sw_1x1->matmul_->base_.train_session_ = self->train_session_; + sw_1x1->matmul_->base_.thread_nr_ = self->thread_nr_; + sw_1x1->matmul_->base_.env_ = self->env_; + + UpdateTensorInfo(self, sw_1x1); + return MatmulConv1x1Prepare(sw_1x1); +} + +int ConvolutionSW1x1Release(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + + if (sw_1x1->matmul_ != NULL) { + sw_1x1->matmul_->matrix_b_.origin_ptr_ = NULL; + sw_1x1->matmul_->matrix_c_.origin_ptr_ = NULL; + + (void)sw_1x1->matmul_->base_.Release(&sw_1x1->matmul_->base_); + + if (sw_1x1->matmul_->base_.param_ != NULL) { + free(sw_1x1->matmul_->base_.param_); + sw_1x1->matmul_->base_.param_ = NULL; + } + + free(sw_1x1->matmul_); + sw_1x1->matmul_ = NULL; + } + + ConvBaseRelease(&sw_1x1->conv_); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const) {
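+ // The creator wires a conv-shaped shell around a standalone matmul kernel. A caller is expected to + // drive it through the generic KernelBase vtable set below, roughly Prepare -> Resize -> Compute -> + // Release (a sketch of the expected call order; the real driver lives in the kernel registry).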
+ ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)malloc(sizeof(ConvolutionSW1x1Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw_1x1); + memset(sw_1x1, 0, sizeof(ConvolutionSW1x1Struct)); + + sw_1x1->conv_.is_sharing_pack_ = false; + sw_1x1->conv_.base_.Compute = ConvolutionSW1x1Compute; + sw_1x1->conv_.base_.Resize = ConvolutionSW1x1Resize; + sw_1x1->conv_.base_.Prepare = ConvolutionSW1x1Prepare; + sw_1x1->conv_.base_.Release = ConvolutionSW1x1Release; + + OpParameter *param = (OpParameter *)malloc(sizeof(MatMulParameter)); + if (param == NULL) { + free(sw_1x1); + return NULL; + } + MatMulParameter *matmul_param = (MatMulParameter *)param; + matmul_param->op_parameter_ = conv_param->op_parameter_; + matmul_param->act_type_ = conv_param->act_type_; + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = true; + + KernelBase *matmul = CreateMatmulKernel(); + if (matmul == NULL) { + free(sw_1x1); + free(param); + return NULL; + } + + ((MatmulStruct *)matmul)->is_sharing_pack_ = false; + ((MatmulStruct *)matmul)->a_const_ = input_const; + ((MatmulStruct *)matmul)->b_const_ = weight_const; + ((MatmulStruct *)matmul)->base_.param_ = param; + sw_1x1->matmul_ = (MatmulStruct *)matmul; + return (ConvolutionBaseStruct *)sw_1x1; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h new file mode 100644 index 0000000000000000000000000000000000000000..46804c5d437d09b1b7c4020a79ce14ee22e68beb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_SW_1X1_H_ +#define NNACL_KERNEL_CONVOLUTION_SW_1X1_H_ + +#ifdef ENABLE_AVX +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/kernel/matmul_struct.h" + +typedef struct ConvolutionSW1x1Struct { + ConvolutionBaseStruct conv_; + MatmulStruct *matmul_; +} ConvolutionSW1x1Struct; + +ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const); +#endif +#endif  // NNACL_KERNEL_CONVOLUTION_SW_1X1_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..4a9194107baa88215c86c4dac24150c63f1b011c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c @@ -0,0 +1,59 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_sw_arm64.h" +#include "nnacl_c/kernel/convolution_slidewindow.h" +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h"
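+ +// ARM64 slide-window tiling: oc_tile_ is C8NUM, i.e. two NEON float32x4 registers per output-channel +// block. oc_res_ records the tail channels; when it is nonzero, ConvolutionSWCompute writes into an +// aligned scratch buffer and repacks it back to NHWC afterwards.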
+void ConvSWARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + conv_sw->oc_tile_ = C8NUM; + conv_sw->oc_res_ = conv_param->output_channel_ % conv_sw->oc_tile_; +} + +int ConvSWARM64RunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + ConvSWArm64Fp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, conv_sw->output_data_, + task_id, conv_param, &conv_sw->sw_param_); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSWARM64(ConvParameter *conv_param) { + ConvolutionSWStruct *sw = (ConvolutionSWStruct *)malloc(sizeof(ConvolutionSWStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw); + memset(sw, 0, sizeof(ConvolutionSWStruct)); + + sw->conv_.run_impl_ = ConvSWARM64RunImpl; + sw->conv_.init_global_variable_ = ConvSWARM64InitGlobalVariable; + sw->conv_.pack_weight_ = ConvSWPackWeight; + sw->conv_.malloc_weight_bias_ = ConvSWMallocWeightBiasData; + + sw->conv_.base_.Compute = ConvolutionSWCompute; + sw->conv_.base_.Prepare = ConvolutionSWPrepare; + sw->conv_.base_.Release = ConvolutionSWRelease; + sw->conv_.base_.Resize = ConvolutionSWResize; + + return (ConvolutionBaseStruct *)sw; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..e546924c8532084e21dfb07109e88c9db538679d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_KERNEL_CONVOLUTION_SW_ARM64_H_ +#define NNACL_KERNEL_CONVOLUTION_SW_ARM64_H_ +#ifdef ENABLE_ARM64 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionSWARM64(ConvParameter *conv_param); +#endif +#endif  // NNACL_KERNEL_CONVOLUTION_SW_ARM64_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..fbcadfca2b0a2ca23320da75cb3a804224220484 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c @@ -0,0 +1,71 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_sw_avx.h" +#include "nnacl_c/kernel/convolution_slidewindow.h" +#include "nnacl_c/fp32/conv_1x1_avx_fp32.h" +#include "nnacl_c/fp32/conv_sw_avx_fp32.h" + +void ConvSWAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + conv_sw->oc_tile_ = C8NUM; + conv_sw->oc_res_ = conv_param->output_channel_ % conv_sw->oc_tile_; + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + // 1x1 conv is aligned to C8NUM + conv_sw->in_tile_ = C8NUM; + conv_sw->ic_res_ = conv_param->input_channel_ % conv_sw->in_tile_; + } +} + +int ConvSWAVXRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv_param->kernel_w_ == 1 && conv_param->kernel_h_ == 1) { + Conv1x1SWAVXFp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_sw->output_data_, task_id, conv_param, &conv_sw->sw_param_); + } else { + ConvSWAVXFp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, conv_sw->output_data_, + task_id, conv_param, &conv_sw->sw_param_); + } + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSWAVX(ConvParameter *conv_param) { + ConvolutionSWStruct *sw = (ConvolutionSWStruct *)malloc(sizeof(ConvolutionSWStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw); + memset(sw, 0, sizeof(ConvolutionSWStruct)); + + sw->conv_.run_impl_ = ConvSWAVXRunImpl; + sw->conv_.init_global_variable_ = ConvSWAVXInitGlobalVariable; + sw->conv_.pack_weight_ = ConvSWPackWeight; + sw->conv_.malloc_weight_bias_ = ConvSWMallocWeightBiasData; + + sw->conv_.base_.Compute = ConvolutionSWCompute; + sw->conv_.base_.Prepare = ConvolutionSWPrepare;
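+ // The remaining vtable entries mirror the ARM64 creator: Prepare/Resize/Compute/Release come from the + // shared slide-window base; only run_impl_ and init_global_variable_ are ISA specific.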
+ sw->conv_.base_.Release = ConvolutionSWRelease; + sw->conv_.base_.Resize = ConvolutionSWResize; + + return (ConvolutionBaseStruct *)sw; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..c2d47268511429e21920878bd4c48ffb54c30d63 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_SW_AVX_H_ +#define NNACL_KERNEL_CONVOLUTION_SW_AVX_H_ +#ifdef ENABLE_AVX +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionSWAVX(ConvParameter *conv_param); +#endif +#endif  // NNACL_KERNEL_CONVOLUTION_SW_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c new file mode 100644 index 0000000000000000000000000000000000000000..d7e464dd92f4bf7bfd44ba2f827fb089b888c9ba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c @@ -0,0 +1,76 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_winograd.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_winograd_avx.h" +#endif +#ifdef ENABLE_SSE +#include "nnacl_c/kernel/convolution_winograd_sse.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_winograd_arm64.h" +#endif +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/convolution_winograd_arm32.h" +#endif + +ConvolutionWinogradBaseStruct *SelectConvolutionWinograd(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *kernel = NULL; + +#ifdef ENABLE_AVX + kernel = CreateConvWinogradAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_SSE + kernel = CreateConvWinogradSSE(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM64 + kernel = CreateConvWinogradARM64(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM32 + kernel = CreateConvWinogradARM32(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + + kernel = CreateConvWinogradBase(conv_param); + return kernel; +} + +ConvolutionBaseStruct *CreateConvolutionWinograd(ConvParameter *conv_param, int out_unit) { + ConvolutionWinogradBaseStruct *kernel = SelectConvolutionWinograd(conv_param); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + + kernel->output_unit_ = out_unit; + kernel->conv_.malloc_weight_bias_ = ConvWinoBaseMallocWeightBiasData; + kernel->conv_.run_impl_ = ConvWinoBaseRunImpl; + kernel->conv_.pack_weight_ = ConvWinoBasePackWeight; + return (ConvolutionBaseStruct *)kernel; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h new file mode 100644 index 0000000000000000000000000000000000000000..23c5e1c86647c01d260051642e241ffcc5d8fd4c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_H_ +#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct ConvolutionWinogradStruct { + ConvolutionBaseStruct conv_; +} ConvolutionWinogradStruct; + +ConvolutionBaseStruct *CreateConvolutionWinograd(ConvParameter *conv_param, int out_unit); + +#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c new file mode 100644 index 0000000000000000000000000000000000000000..f22088a80aa5328ef21aa265f4cad77f84466239 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/convolution_winograd_arm32.h" + +void ConvWinoARM32InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM32(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoARM32InitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return winograd; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h new file mode 100644 index 0000000000000000000000000000000000000000..21b32c4a491f2770d5a27a5f7c580d660d7a39e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM32_H_ +#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM32(ConvParameter *conv_param); +#endif + +#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..0594f2da4e4da891155094006b142cc63a349b09 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c @@ -0,0 +1,60 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_winograd_arm64.h" +#include "nnacl_c/kernel/convolution_winograd_base.h"
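+ +// ARM64 Winograd tiling: 8-channel output blocks (oc_block_ = C8NUM), 12 output tiles per GEMM pass +// (tile_num_ = C12NUM) and a 4-float scratch tile (tmp_data_tile_ = C4NUM). Unlike the ARM32 variant, +// this one installs the step/pack input-transform functions in its own ConfigInputOutput below.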
+void ConvWinoARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} + +int ConvWinoARM64ConfigInputOutput(ConvolutionWinogradBaseStruct *winograd) { + winograd->transfer_functions_.in_func_ = GetInputTransFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + winograd->transfer_functions_.in_step_func_ = GetInputTransStepFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_step_func_); + + winograd->transfer_functions_.in_pack_func_ = GetInputTransPackFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_pack_func_); + + ActType act_type = ((ConvParameter *)winograd->conv_.base_.param_)->act_type_; + winograd->transfer_functions_.out_func_ = GetOutputTransFunc(winograd->input_unit_, winograd->output_unit_, act_type); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.out_func_); + + return NNACL_OK; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM64(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoARM64ConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoARM64InitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return winograd; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..d2e98d4f94fe7f1e133a80ab7fde1d21bec0e9f1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM64_H_ +#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM64(ConvParameter *conv_param); +#endif + +#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM64_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..30242ae440996c08c35c661c8b0412a885a1853e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c @@ -0,0 +1,43 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_winograd_avx.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" + +void ConvWinoAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C16NUM; + winograd->tmp_data_tile_ = C8NUM; + winograd->tile_num_ = C12NUM; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradAVX(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoAVXInitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return winograd; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..baa831823ddc0dac914853a4326c7551c364052b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_AVX_H_ +#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradAVX(ConvParameter *conv_param); +#endif + +#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c new file mode 100644 index 0000000000000000000000000000000000000000..f71e7fceff4216b828b0f154dfeb5378138eb2d4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c @@ -0,0 +1,320 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/convolution_winograd_base.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" + +int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd); + + // set data + size_t trans_matrix_data_size = winograd->input_unit_ * winograd->input_unit_ * conv->compute_.in_c_ * + UP_ROUND(conv->compute_.out_c_, winograd->oc_block_) * sizeof(float); + if (!conv->base_.train_session_) { + if (conv->packed_weight_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(trans_matrix_data_size); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, trans_matrix_data_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + } + + float matrix_a[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_at[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_b[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_bt[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float coef = 1.0f; + if (winograd->input_unit_ == CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE) { + coef = 0.5f; + } + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, winograd->matrix_g_, winograd->matrix_gt_, coef, + winograd->output_unit_, winograd->kernel_unit_); + if (ret != NNACL_OK) { + return ret; + } + + // init bias + size_t new_bias_size = UP_ROUND(conv->compute_.out_c_, C4NUM) * sizeof(float); + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(new_bias_size); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, new_bias_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, new_bias_size); + return NNACL_OK; +} + +void ConvWinoBaseFreeTmpBuffer(ConvolutionWinogradBaseStruct *winograd) { + ExecEnv *env = winograd->conv_.base_.env_;
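+ // Compute() allocates five scratch buffers per run (trans_input_, gemm_out_, tmp_data_, col_buffer_, + // opt_input_trans_); all of them are released here after each ParallelLaunch. +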
NNACL_CHECK_NULL_RETURN_VOID(env); + + if (winograd->trans_input_ != NULL) { + env->Free(env->allocator_, winograd->trans_input_); + winograd->trans_input_ = NULL; + } + if (winograd->tmp_data_ != NULL) { + env->Free(env->allocator_, winograd->tmp_data_); + winograd->tmp_data_ = NULL; + } + if (winograd->gemm_out_ != NULL) { + env->Free(env->allocator_, winograd->gemm_out_); + winograd->gemm_out_ = NULL; + } + if (winograd->col_buffer_ != NULL) { + env->Free(env->allocator_, winograd->col_buffer_); + winograd->col_buffer_ = NULL; + } + if (winograd->opt_input_trans_ != NULL) { + env->Free(env->allocator_, winograd->opt_input_trans_); + winograd->opt_input_trans_ = NULL; + } +} + +void ConvWinoBaseInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} + +int ConvWinoBaseWinogradFilterTransform(ConvolutionWinogradBaseStruct *winograd, const float *weight_data) { + NNACL_CHECK_ZERO_RETURN_ERR(winograd->oc_block_); + return WinogradWeightTransform(weight_data, (float *)winograd->conv_.packed_weight_, winograd->matrix_g_, + winograd->matrix_gt_, winograd->oc_block_, winograd->input_unit_, + winograd->kernel_unit_, winograd->conv_.compute_.in_c_, + winograd->conv_.compute_.out_c_, true); +} + +void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(winograd); + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *origin_weight = (conv->base_.train_session_) ? weight_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + ConvWinoBaseWinogradFilterTransform(winograd, (float *)origin_weight); +} + +int ConvolutionWinogradBasePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + winograd->conv_.init_global_variable_(&winograd->conv_); + + winograd->kernel_unit_ = winograd->conv_.compute_.kernel_h_; + winograd->input_unit_ = winograd->output_unit_ + winograd->kernel_unit_ - 1; + + if (self->train_session_) { + TensorC *filter_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + NNACL_CHECK_FALSE(filter_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int input_plane = winograd->input_unit_ * winograd->input_unit_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_plane, winograd->conv_.compute_.in_c_, NNACL_ERR); + int in_chw = input_plane * winograd->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(in_chw, UP_ROUND(winograd->conv_.compute_.out_c_, winograd->oc_block_), NNACL_ERR); + int trans_matrix_data_size = + in_chw * UP_ROUND(winograd->conv_.compute_.out_c_, winograd->oc_block_) * sizeof(float); + self->work_size_ = trans_matrix_data_size; + } + + return ConvBaseInitConvWeightBias(&winograd->conv_); +} + +int ConvoWinoBaseUpdateThreadNumProcess(ConvolutionWinogradBaseStruct *winograd) { + if (winograd->conv_.compute_.in_n_ % winograd->conv_.base_.thread_nr_ == 0) { + winograd->conv_.use_batch_cut_flag_ = true; + return NNACL_OK; + } else { + winograd->conv_.use_batch_cut_flag_ = false; + } + + int 
update_thread = UP_DIV(UP_DIV(winograd->conv_.compute_.out_hw_, C12NUM), ConvMinBlock); + winograd->conv_.base_.thread_nr_ = NNACL_MIN(update_thread, winograd->conv_.base_.thread_nr_); + return NNACL_OK; +} + +int ConvoWinoBaseUpdateThread(ConvolutionWinogradBaseStruct *winograd) { +#ifdef DYNAMIC_THREAD_DISTRIBUTE + ConvoWinoBaseUpdateThreadNumProcess(winograd); +#else + KernelBase *base = &winograd->conv_.base_; + base->thread_nr_ = base->UpdateThread(TC_PTYPE(PrimType_Conv2DFusion), 0, 0, 0, base->thread_nr_); +#endif + return NNACL_OK; +} + +int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd) { + winograd->transfer_functions_.in_func_ = GetInputTransFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + ConvParameter *conv_param = (ConvParameter *)winograd->conv_.base_.param_; + winograd->transfer_functions_.out_func_ = + GetOutputTransFunc(winograd->input_unit_, winograd->output_unit_, conv_param->act_type_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.out_func_); + + return NNACL_OK; +} + +int ConvoWinoBaseInitTmpBuffer(ConvolutionWinogradBaseStruct *winograd) { + ExecEnv *env = winograd->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int thread_input_plane = winograd->conv_.base_.thread_nr_ * winograd->input_unit_ * winograd->input_unit_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(winograd->tile_num_, thread_input_plane, NNACL_ERR); + int total_thread_input_plane = winograd->tile_num_ * thread_input_plane; + size_t tile_buffer_size = total_thread_input_plane * winograd->conv_.compute_.in_c_ * sizeof(float); + winograd->trans_input_ = (float *)env->Alloc(env->allocator_, tile_buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->trans_input_); + + int oc8 = UP_ROUND(winograd->conv_.compute_.out_c_, C8NUM); + winograd->gemm_out_ = env->Alloc(env->allocator_, total_thread_input_plane * oc8 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->gemm_out_); + + winograd->tmp_data_ = env->Alloc(env->allocator_, winograd->tmp_data_tile_ * thread_input_plane * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->tmp_data_); + + winograd->col_buffer_ = env->Alloc(env->allocator_, winograd->conv_.base_.thread_nr_ * winograd->tile_num_ * + winograd->conv_.compute_.in_c_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->col_buffer_); + + int tile = UP_ROUND(winograd->conv_.compute_.in_c_, winograd->tmp_data_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_input_plane, tile, NNACL_ERR); + winograd->opt_input_trans_ = env->Alloc(env->allocator_, total_thread_input_plane * tile * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->opt_input_trans_); + + winograd->tmp_buffer_address_list_[Index0] = winograd->trans_input_; + winograd->tmp_buffer_address_list_[Index1] = winograd->gemm_out_; + winograd->tmp_buffer_address_list_[Index2] = winograd->tmp_data_; + winograd->tmp_buffer_address_list_[Index3] = winograd->col_buffer_; + winograd->tmp_buffer_address_list_[Index4] = winograd->opt_input_trans_; + return NNACL_OK; +} + +int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_data = (float 
*)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + + TensorC *output_tensor = conv->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_data = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + if (conv->use_batch_cut_flag_) { + ConvWinogardFp32CutByBatch(input_data, (float *)conv->packed_weight_, (float *)conv->bias_data_, output_data, + winograd->tmp_buffer_address_list_, task_id, conv_param, winograd->transfer_functions_); + } else { + ConvWinogardFp32(input_data, (float *)conv->packed_weight_, (float *)conv->bias_data_, output_data, + winograd->tmp_buffer_address_list_, task_id, conv_param, winograd->transfer_functions_); + } + + return NNACL_OK; +} + +int ConvWinoImpl(void *cdata, int task_id, float l, float r) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv); + return conv->run_impl_(conv, task_id); +} + +void ConvWinoBaseUpdateParam(ConvParameter *param, ConvolutionWinogradBaseStruct *winograd) { + param->input_unit_ = winograd->input_unit_; + param->output_unit_ = winograd->output_unit_; +} + +int ConvolutionWinogradBaseResize(KernelBase *self) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + int ret = ConvBaseCheckResizeValid(&winograd->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&winograd->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvoWinoBaseUpdateThread(winograd); + if (ret != NNACL_OK) { + return ret; + } + + ret = winograd->config_input_output_(winograd); + if (ret != NNACL_OK) { + return ret; + } + + ConvWinoBaseUpdateParam((ConvParameter *)self->param_, winograd); + return NNACL_OK; +} + +int ConvolutionWinogradBaseCompute(KernelBase *self) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + int ret = ConvoWinoBaseInitTmpBuffer(winograd); + if (ret != NNACL_OK) { + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; + } + + ret = ConvBaseRepackWeight(&winograd->conv_); + if (ret != NNACL_OK) { + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvWinoImpl, self, self->thread_nr_); + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; +} + +int ConvolutionWinogradBaseRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoBaseInitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return (ConvolutionWinogradBaseStruct *)winograd; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..85ffca6b3c9e34037fde1e8aabd68b1ac2a40e43
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_BASE_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+
+#define CONVOLUTION_WINOGRAD_MATRIX_SIZE 64
+#define CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE 5
+#define CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE 8
+
+typedef float *TmpBufferAddress;
+
+typedef struct ConvolutionWinogradBaseStruct {
+  ConvolutionBaseStruct conv_;
+
+  int kernel_unit_;
+  int input_unit_;
+  int output_unit_;
+  int oc_block_;
+  int tile_num_;
+  int tmp_data_tile_;
+  float *tmp_data_;
+  float *trans_input_;
+  float *gemm_out_;
+  float *col_buffer_;
+  float *opt_input_trans_;
+  float matrix_g_[CONVOLUTION_WINOGRAD_MATRIX_SIZE];
+  float matrix_gt_[CONVOLUTION_WINOGRAD_MATRIX_SIZE];
+  TmpBufferAddress tmp_buffer_address_list_[CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE];
+  TransFuncList transfer_functions_;
+
+  int (*config_input_output_)(struct ConvolutionWinogradBaseStruct *winograd);
+} ConvolutionWinogradBaseStruct;
+
+void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv);
+int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd);
+int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id);
+int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv);
+int ConvolutionWinogradBasePrepare(KernelBase *self);
+int ConvolutionWinogradBaseResize(KernelBase *self);
+int ConvolutionWinogradBaseRelease(KernelBase *self);
+int ConvolutionWinogradBaseCompute(KernelBase *self);
+ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c
new file mode 100644
index 0000000000000000000000000000000000000000..d91dbc50076a52c70befc3d49c1f702d4cdb393c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_SSE
+#include "nnacl_c/kernel/convolution_winograd_sse.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+void ConvWinoSSEInitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv;
+  winograd->oc_block_ = C8NUM;
+  winograd->tmp_data_tile_ = C4NUM;
+  winograd->tile_num_ = C12NUM;
+}
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradSSE(ConvParameter *conv_param) {
+  ConvolutionWinogradBaseStruct *winograd =
+    (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd);
+  memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct));
+
+  winograd->config_input_output_ = ConvWinoBaseConfigInputOutput;
+  winograd->conv_.init_global_variable_ = ConvWinoSSEInitGlobalVariable;
+
+  winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare;
+  winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize;
+  winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease;
+  winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute;
+
+  return winograd;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h
new file mode 100644
index 0000000000000000000000000000000000000000..82755a52eca564dadc8e15f578e16d095db4db61
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_SSE_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_SSE_H_
+
+#ifdef ENABLE_SSE
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradSSE(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_SSE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c
new file mode 100644
index 0000000000000000000000000000000000000000..16244cb8ba77ca858d285933e765274ebbb6e476
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/crop.h"
+#include "nnacl_c/base/crop_base.h"
+#include "nnacl_c/fp32/crop_fp32.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/crop_fp16.h"
+#endif
+
+int CropLaunch(void *cdata, int task_id, float l, float r) {
+  CropStruct *crop = (CropStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(crop);
+
+  TensorC *in = crop->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(in);
+  TensorC *out = crop->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out);
+
+#ifdef ENABLE_FP16
+  if (in->data_type_ == kNumberTypeFloat16) {
+    Fp16Crop((float16_t *)in->data_, (float16_t *)out->data_, in->shape_, out->shape_, crop->in_offset_,
+             in->shape_size_, task_id, crop->base_.thread_nr_);
+    return NNACL_OK;
+  }
+#endif
+
+  CropParameter *crop_param = (CropParameter *)crop->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(crop_param);
+  Crop4D((float *)in->data_, (float *)out->data_, in->shape_, out->shape_, crop_param, task_id, crop->base_.thread_nr_);
+  return NNACL_OK;
+}
+
+int CropResize(struct KernelBase *self) {
+  TensorC *in_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
+  TensorC *out_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+  NNACL_CHECK_FALSE(out_tensor->shape_size_ <= Num1, NNACL_OUTPUT_TENSOR_ERROR);
+
+  CropStruct *crop = (CropStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(crop);
+  CropParameter *crop_param = (CropParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(crop_param);
+
+  return CropPadOffset(in_tensor->shape_size_, crop_param, crop->in_offset_);
+}
+
+int CropCompute(struct KernelBase *self) {
+  TensorC *in_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
+  TensorC *out_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+  CropParameter *crop_param = (CropParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(crop_param);
+
+  if (in_tensor->data_type_ != kNumberTypeFloat16 && out_tensor->shape_[Index1] < self->thread_nr_) {
+    float *input_data = (float *)in_tensor->data_;
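+    /* Single-threaded fast path (fp32 only): when the second output dimension is smaller than
+     * thread_nr_, the parallel launch overhead presumably outweighs the copy itself, so the
+     * whole crop is done inline with Crop4DNoParallel below. */
+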
NNACL_CHECK_NULL_RETURN_ERR(input_data); + float *output_data = (float *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + Crop4DNoParallel(input_data, output_data, in_tensor->shape_, out_tensor->shape_, crop_param); + return NNACL_OK; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, CropLaunch, self, self->thread_nr_); +} + +KernelBase *CreateCrop(OpParameter *param, int data_type) { + CropStruct *crop = (CropStruct *)malloc(sizeof(CropStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop); + memset(crop, 0, sizeof(CropStruct)); + crop->base_.Prepare = DefaultPrepare1In1Out; + crop->base_.Resize = CropResize; + crop->base_.Release = DefaultRelease; + crop->base_.Compute = CropCompute; + return (KernelBase *)crop; +} + +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeInt32, CreateCrop) +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat32, CreateCrop) +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat16, CreateCrop) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h new file mode 100644 index 0000000000000000000000000000000000000000..26408dd7a3a65e341f90ebbd6ed9c4a5ce90e64f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CROP_H_ +#define NNACL_KERNEL_CROP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct CropStruct { + KernelBase base_; + int64_t in_offset_[COMM_SHAPE_SIZE]; +} CropStruct; + +KernelBase *CreateCrop(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CROP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c new file mode 100644 index 0000000000000000000000000000000000000000..0c0054b36bff152a3bb1f69fe80abd81add7e4dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c @@ -0,0 +1,190 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/crop_and_resize.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/fp32/resize_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int CropAndResizeMallocTmpBuffer(CropAndResizeStruct *crop_and_resize) {
+  TensorC *input_tensor = crop_and_resize->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *output_tensor = crop_and_resize->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+  ExecEnv *env = crop_and_resize->base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+
+  // Malloc buffers to cache the resize coordinates.
+  // For CROP_AND_RESIZE mode, each output batch requires its own cached coordinates.
+  crop_and_resize->batch_ = NNACLGetBatch(output_tensor);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_height_, crop_and_resize->batch_, NNACL_ERR);
+  int height_size = crop_and_resize->new_height_ * crop_and_resize->batch_;
+  NNACL_CHECK_MALLOC_SIZE(height_size);
+  crop_and_resize->y_bottoms_ = (int *)env->Alloc(env->allocator_, height_size * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_bottoms_);
+  crop_and_resize->y_tops_ = (int *)env->Alloc(env->allocator_, height_size * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_tops_);
+  crop_and_resize->y_bottom_weights_ = (float *)env->Alloc(env->allocator_, height_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_bottom_weights_);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_width_, crop_and_resize->batch_, NNACL_ERR);
+  int width_size = crop_and_resize->new_width_ * crop_and_resize->batch_;
+  NNACL_CHECK_MALLOC_SIZE(width_size);
+  crop_and_resize->x_lefts_ = (int *)env->Alloc(env->allocator_, width_size * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_lefts_);
+  crop_and_resize->x_rights_ = (int *)env->Alloc(env->allocator_, width_size * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_rights_);
+  crop_and_resize->x_left_weights_ = (float *)env->Alloc(env->allocator_, width_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_left_weights_);
+
+  int c = NNACLGetChannel(input_tensor);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_width_, c, NNACL_ERR);
+  int new_wc = crop_and_resize->new_width_ * c;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(new_wc, crop_and_resize->mapped_point_num_, NNACL_ERR);
+  int total_point_num = new_wc * crop_and_resize->mapped_point_num_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_point_num, crop_and_resize->base_.thread_nr_, NNACL_ERR);
+  int line_buffer_size = total_point_num * crop_and_resize->base_.thread_nr_ * sizeof(float);
+  crop_and_resize->line_buffer_ = (float *)env->Alloc(env->allocator_, line_buffer_size);
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->line_buffer_);
+  return NNACL_OK;
+}
+
+void CropAndResizeFreeTmpBuffer(CropAndResizeStruct *crop_and_resize) {
+  ExecEnv *env = crop_and_resize->base_.env_;
+  NNACL_CHECK_NULL_RETURN_VOID(env);
+  env->Free(env->allocator_, crop_and_resize->y_bottoms_);
+  env->Free(env->allocator_, crop_and_resize->y_tops_);
+  env->Free(env->allocator_, crop_and_resize->y_bottom_weights_);
+  env->Free(env->allocator_, crop_and_resize->x_lefts_);
+  env->Free(env->allocator_, crop_and_resize->x_rights_);
+  env->Free(env->allocator_, crop_and_resize->x_left_weights_);
+  env->Free(env->allocator_, crop_and_resize->line_buffer_);
+  crop_and_resize->y_bottoms_ = NULL;
+  crop_and_resize->y_tops_ = NULL;
+
crop_and_resize->y_bottom_weights_ = NULL; + crop_and_resize->x_lefts_ = NULL; + crop_and_resize->x_rights_ = NULL; + crop_and_resize->x_left_weights_ = NULL; + crop_and_resize->line_buffer_ = NULL; +} + +int CropAndResizeImpl(void *cdata, int task_id, float l, float r) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize); + + TensorC *input = crop_and_resize->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *boxes = crop_and_resize->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(boxes); + TensorC *box_idx = crop_and_resize->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(box_idx); + TensorC *output = crop_and_resize->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int unit = UP_DIV(crop_and_resize->new_height_, crop_and_resize->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(unit, task_id, NNACL_ERR); + int h_begin = unit * task_id; + int h_end = MSMIN(h_begin + unit, crop_and_resize->new_height_); + if (h_end <= h_begin) { + return NNACL_OK; + } + + float extrapolation_value = ((CropAndResizeParameter *)crop_and_resize->base_.param_)->extrapolation_value_; + int c = input->shape_[kNHWC_C]; + float *line0 = crop_and_resize->line_buffer_ + crop_and_resize->new_width_ * c * 2 * task_id; + float *line1 = line0 + crop_and_resize->new_width_ * c; + + return CropAndResizeBilinear((float *)input->data_, (float *)output->data_, (int32_t *)box_idx->data_, + (float *)boxes->data_, extrapolation_value, input->shape_, output->shape_, + crop_and_resize->y_bottoms_, crop_and_resize->y_tops_, crop_and_resize->x_lefts_, + crop_and_resize->x_rights_, crop_and_resize->y_bottom_weights_, + crop_and_resize->x_left_weights_, line0, line1, h_begin, h_end); +} + +int CropAndResizeCompute(struct KernelBase *self) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize); + + // In Prepare() stage, in_tensor[0] may be of fp16 data type in fp16 mode, so move type checks here. 
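+  // Overall flow: allocate the coordinate caches, validate the box indices against the input
+  // batch, precompute the bilinear y/x coordinates and weights once, then resample rows in
+  // parallel via CropAndResizeImpl and free the caches again.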
+  TensorC *input_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *boxes_tensor = self->in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(boxes_tensor);
+  TensorC *boxidx_tensor = self->in_[THIRD_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(boxidx_tensor);
+  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  int ret = CropAndResizeMallocTmpBuffer(crop_and_resize);
+  if (ret != NNACL_OK) {
+    CropAndResizeFreeTmpBuffer(crop_and_resize);
+    return ret;
+  }
+
+  float *boxes = (float *)boxes_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(boxes);
+  int32_t *box_idx = (int32_t *)boxidx_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(box_idx);
+
+  if (CheckCropAndResizeBoxIdx(box_idx, boxes_tensor->shape_[Index0], NNACLGetBatch(input_tensor)) != NNACL_OK) {
+    CropAndResizeFreeTmpBuffer(crop_and_resize);
+    return NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID;
+  }
+
+  ret = PrepareCropAndResizeBilinear(input_tensor->shape_, boxes, box_idx, output_tensor->shape_,
+                                     crop_and_resize->y_bottoms_, crop_and_resize->y_tops_, crop_and_resize->x_lefts_,
+                                     crop_and_resize->x_rights_, crop_and_resize->y_bottom_weights_,
+                                     crop_and_resize->x_left_weights_);
+  if (ret != NNACL_OK) {
+    CropAndResizeFreeTmpBuffer(crop_and_resize);
+    return ret;
+  }
+
+  int error_code = self->env_->ParallelLaunch(self->env_->thread_pool_, CropAndResizeImpl, self, self->thread_nr_);
+  CropAndResizeFreeTmpBuffer(crop_and_resize);
+  return error_code;
+}
+
+int CropAndResizeResize(KernelBase *self) {
+  CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize);
+  TensorC *output = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+  crop_and_resize->new_height_ = output->shape_[Index1];
+  crop_and_resize->new_width_ = output->shape_[Index2];
+  return NNACL_OK;
+}
+
+int CropAndResizePrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);
+  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]);
+  return NNACL_OK;
+}
+
+KernelBase *CreateCropAndResize(OpParameter *param, int data_type) {
+  CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)malloc(sizeof(CropAndResizeStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop_and_resize);
+  memset(crop_and_resize, 0, sizeof(CropAndResizeStruct));
+  crop_and_resize->mapped_point_num_ = Num2;
+  crop_and_resize->base_.Prepare = CropAndResizePrepare;
+  crop_and_resize->base_.Resize = CropAndResizeResize;
+  crop_and_resize->base_.Compute = CropAndResizeCompute;
+  crop_and_resize->base_.Release = DefaultRelease;
+  return (KernelBase *)crop_and_resize;
+}
+
+REG_KERNEL_CREATOR(PrimType_CropAndResize, kNumberTypeFloat32, CreateCropAndResize)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h
new file mode 100644
index 0000000000000000000000000000000000000000..6d5a0d19c8507c6c00324e0221876de58595ecef
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CROP_AND_RESIZE_H_ +#define NNACL_KERNEL_CROP_AND_RESIZE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int mapped_point_num_; + int batch_; + int new_height_; + int new_width_; + int *y_tops_; + int *y_bottoms_; + int *x_lefts_; + int *x_rights_; + float *y_bottom_weights_; + float *x_left_weights_; + float *line_buffer_; +} CropAndResizeStruct; + +KernelBase *CreateCropAndResize(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CROP_AND_RESIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c new file mode 100644 index 0000000000000000000000000000000000000000..ce7a66a7b5a5bc80d5ad5bf6ca2a0a0525891300 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c @@ -0,0 +1,337 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/deconvolution.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/deconvolution_winograd.h" +#include "nnacl_c/kernel/deconvolution_depthwise.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/deconv_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_avx_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +int DeConvMallocWeightBiasData(ConvolutionBaseStruct *conv) { + int output_aligned_size = UP_ROUND(conv->compute_.out_c_, C8NUM) * sizeof(float); + size_t pack_weight_size = conv->compute_.in_c_ * conv->compute_.kernel_hw_ * output_aligned_size; + if (!conv->base_.train_session_) { + conv->packed_weight_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, pack_weight_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + if (conv->bias_data_ == NULL) { + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, output_aligned_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, output_aligned_size); + return NNACL_OK; +} + +void DeConvPackWeight(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *weight_data = weight_tensor->data_ == NULL ? 
conv->origin_weight_ : weight_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(weight_data); + +#ifdef ENABLE_AVX + PackNHWCToCXHWNXFp32((float *)weight_data, (float *)conv->packed_weight_, conv->compute_.in_c_, + conv->compute_.kernel_hw_, conv->compute_.out_c_); +#else + PackNHWCToC8HWN8Fp32((float *)weight_data, (float *)conv->packed_weight_, conv->compute_.in_c_, + conv->compute_.kernel_hw_, conv->compute_.out_c_); +#endif +} + +int DeConvInitParam(DeConvStruct *deconv) { + ConvComputeParam *compute = &deconv->conv_.compute_; + deconv->matmul_.row_ = compute->in_hw_; + deconv->matmul_.deep_ = compute->in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_c_, compute->kernel_hw_, NNACL_ERR); + deconv->matmul_.col_ = compute->out_c_ * compute->kernel_hw_; + deconv->matmul_.row_align_ = UP_ROUND(deconv->matmul_.row_, deconv->matmul_.row_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(UP_ROUND(compute->out_c_, C8NUM), compute->kernel_hw_, NNACL_ERR); + deconv->matmul_.col_align_ = UP_ROUND(compute->out_c_, C8NUM) * compute->kernel_hw_; + + deconv->conv_.base_.thread_nr_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, UP_DIV(compute->out_c_, C8NUM)); + NNACL_CHECK_ZERO_RETURN_ERR(deconv->conv_.base_.thread_nr_); +#ifdef ENABLE_AVX + deconv->thread_stride_ = UP_DIV(UP_DIV(compute->out_c_, C8NUM * C3NUM), deconv->conv_.base_.thread_nr_) * C3NUM; +#else + deconv->thread_stride_ = UP_DIV(UP_DIV(compute->out_c_, C8NUM), deconv->conv_.base_.thread_nr_); +#endif + return NNACL_OK; +} + +int DeConvRun(void *cdata, int task_id, float l, float r) { + DeConvStruct *deconv = (DeConvStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + int total_thead_stride_ = task_id * deconv->thread_stride_; + int res_stride = UP_DIV(deconv->conv_.compute_.out_c_, C8NUM) - total_thead_stride_; + int oc = NNACL_MIN(deconv->thread_stride_, res_stride); + int cur_stride = deconv->thread_stride_ * C8NUM; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_, C8NUM, NNACL_ERR); + int total_thead_stride_c8 = total_thead_stride_ * C8NUM; + res_stride = deconv->conv_.compute_.out_c_ - total_thead_stride_c8; + int oc_res = NNACL_MIN(cur_stride, res_stride); + if (oc <= 0 || oc_res <= 0) { + return NNACL_OK; + } + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_c8, deconv->conv_.compute_.kernel_hw_, NNACL_ERR); + int plane_thead_stride_c8 = total_thead_stride_c8 * deconv->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_thead_stride_c8, deconv->matmul_.row_align_, NNACL_ERR); + int row_c8 = plane_thead_stride_c8 * deconv->matmul_.row_align_; + float *tmp_buffer = deconv->tmp_buffer_ + row_c8; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_thead_stride_c8, deconv->matmul_.deep_, NNACL_ERR); + int deep_c8 = plane_thead_stride_c8 * deconv->matmul_.deep_; + +#ifdef ENABLE_AVX + DeconvMatmulAvx(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, + deconv->matmul_.deep_, deconv->matmul_.row_align_, oc * C8NUM * deconv->conv_.compute_.kernel_hw_, + deconv->conv_.compute_.kernel_hw_); +#elif ENABLE_SSE + DeconvMatmulFloatSse(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, + deconv->matmul_.deep_, deconv->matmul_.row_align_, + oc * C8NUM * deconv->conv_.compute_.kernel_hw_); +#else + MatMulOpt(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, NULL, ActType_No, + deconv->matmul_.deep_, deconv->matmul_.row_align_, oc * C8NUM * deconv->conv_.compute_.kernel_hw_, + deconv->matmul_.col_, OutType_C8); +#endif + + 
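+  /* Post step: the GEMM above wrote this thread's C8 channel blocks into tmp_buffer as
+   * kernel_hw_ tiles; DeConvPostFp32C8 merges the overlapping tiles back into the output
+   * plane and handles bias/activation via the ConvParameter. Illustrative numbers: with
+   * out_c_ = 32 there are UP_DIV(32, C8NUM) = 4 channel blocks; given thread_stride_ = 2,
+   * task 1 starts at block 2 and computes oc = NNACL_MIN(2, 4 - 2) = 2 blocks. */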
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_c8, deconv->conv_.compute_.out_hw_, NNACL_ERR);
+  DeConvPostFp32C8(tmp_buffer, deconv->pack_output_ + total_thead_stride_c8 * deconv->conv_.compute_.out_hw_,
+                   (float *)deconv->conv_.bias_data_ + total_thead_stride_c8,
+                   deconv->output_ptr_ + total_thead_stride_c8, oc_res, (ConvParameter *)deconv->conv_.base_.param_);
+  return NNACL_OK;
+}
+
+void DeConvFreeRunBuf(DeConvStruct *deconv) {
+  ExecEnv *env = deconv->conv_.base_.env_;
+  NNACL_CHECK_NULL_RETURN_VOID(env);
+
+  if (deconv->pack_output_ != NULL) {
+    env->Free(env->allocator_, deconv->pack_output_);
+    deconv->pack_output_ = NULL;
+  }
+  if (deconv->tmp_buffer_ != NULL) {
+    env->Free(env->allocator_, deconv->tmp_buffer_);
+    deconv->tmp_buffer_ = NULL;
+  }
+  if (deconv->pack_input_ != NULL) {
+    env->Free(env->allocator_, deconv->pack_input_);
+    deconv->pack_input_ = NULL;
+  }
+}
+
+int DeConvInitRunBuf(DeConvStruct *deconv) {
+  ExecEnv *env = deconv->conv_.base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+
+  int pack_output_size = UP_ROUND(deconv->conv_.compute_.out_c_, C8NUM) * deconv->conv_.compute_.out_hw_;
+  deconv->pack_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->pack_output_);
+
+  int tmp_buffer_size = deconv->matmul_.row_align_ * deconv->matmul_.col_align_;
+  deconv->tmp_buffer_ = (float *)env->Alloc(env->allocator_, tmp_buffer_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tmp_buffer_);
+
+  int pack_input_size = deconv->matmul_.row_align_ * deconv->matmul_.deep_;
+  deconv->pack_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->pack_input_);
+
+  return NNACL_OK;
+}
+
+int DeConvCheckvResizeValid(ConvolutionBaseStruct *conv) {
+  // ===============check in channel================= //
+  TensorC *input_tensor = conv->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT];
+
+  int resize_out_channel = NNACLGetChannel(input_tensor);
+  int filter_out_channel = NNACLGetBatch(filter_tensor);
+  if (filter_out_channel != resize_out_channel) {
+    return NNACL_DECONV_RESIZE_OC_INVALID;
+  }
+  return NNACL_OK;
+}
+
+int DeConvResize(KernelBase *self) {
+  DeConvStruct *deconv = (DeConvStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(deconv);
+
+  (void)ConvBaseUpdateComputeInfo(&deconv->conv_);
+
+  int ret = DeConvCheckvResizeValid(&deconv->conv_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = ConvBasePrepare(&deconv->conv_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = DeConvInitParam(deconv);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  return NNACL_OK;
+}
+
+int DeConvCompute(KernelBase *self) {
+  DeConvStruct *deconv = (DeConvStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(deconv);
+
+  int error_code = ConvBaseRepackWeight(&deconv->conv_);
+  if (error_code != NNACL_OK) {
+    return error_code;
+  }
+
+  error_code = DeConvInitRunBuf(deconv);
+  if (error_code != NNACL_OK) {
+    DeConvFreeRunBuf(deconv);
+    return error_code;
+  }
+
+  float *src_in = (float *)self->in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(src_in);
+  float *src_out = (float *)self->out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(src_out);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_n_ - 1, deconv->conv_.compute_.in_c_, NNACL_ERR);
+  int input_bc = (deconv->conv_.compute_.in_n_ - 1) * deconv->conv_.compute_.in_c_;
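+  /* The overflow guards below bound the largest per-batch offsets, (in_n_ - 1) * channels
+   * times the input/output plane, used when advancing the batch pointers in the loop; each
+   * batch is then repacked to the tiled column layout the matmul expects (Col4 on ARM32/SSE,
+   * Col12 otherwise). */
+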
NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_hw_, input_bc, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.out_hw_, input_bc, NNACL_ERR); + for (int batch_index = 0; batch_index < deconv->conv_.compute_.in_n_; batch_index++) { + deconv->input_ptr_ = src_in + batch_index * deconv->conv_.compute_.in_hw_ * deconv->conv_.compute_.in_c_; + deconv->output_ptr_ = src_out + batch_index * deconv->conv_.compute_.out_hw_ * deconv->conv_.compute_.out_c_; + +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(deconv->input_ptr_, deconv->pack_input_, deconv->matmul_.row_, deconv->matmul_.deep_); +#else + RowMajor2Col12Major(deconv->input_ptr_, deconv->pack_input_, deconv->matmul_.row_, deconv->matmul_.deep_); +#endif + + error_code = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvRun, self, self->thread_nr_); + if (error_code != NNACL_OK) { + DeConvFreeRunBuf(deconv); + return error_code; + } + } + + DeConvFreeRunBuf(deconv); + return NNACL_OK; +} + +int DeConvPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvParameter *param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + // There could be weight dataType casting before Prepare, thus weight update is required. + ConvBaseUpdateOriginWeightAndBias(&deconv->conv_); + +#if defined(ENABLE_ARM32) || defined(ENABLE_AVX) || defined(ENABLE_SSE) + deconv->matmul_.row_tile_ = C4NUM; +#else + deconv->matmul_.row_tile_ = C12NUM; +#endif + + if (self->train_session_) { + int output_aligned_size = UP_ROUND(deconv->conv_.compute_.out_c_, C8NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_c_, deconv->conv_.compute_.kernel_hw_, NNACL_ERR); + int kernel_chw = deconv->conv_.compute_.in_c_ * deconv->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, output_aligned_size, NNACL_ERR); + size_t pack_weight_size = kernel_chw * output_aligned_size * sizeof(float); + self->work_size_ = pack_weight_size; + } + + if (self->in_[SECOND_INPUT]->data_ != NULL) { + int error_code = ConvBaseInitConvWeightBias(&deconv->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + } else { + deconv->conv_.is_repack_ = true; + } + + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateDeConv(ConvParameter *param) { + DeConvStruct *deconv = (DeConvStruct *)malloc(sizeof(DeConvStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv); + memset(deconv, 0, sizeof(DeConvStruct)); + deconv->conv_.malloc_weight_bias_ = DeConvMallocWeightBiasData; + deconv->conv_.pack_weight_ = DeConvPackWeight; + deconv->conv_.base_.Prepare = DeConvPrepare; + deconv->conv_.base_.Resize = DeConvResize; + deconv->conv_.base_.Release = DefaultRelease; + deconv->conv_.base_.Compute = DeConvCompute; + return &deconv->conv_; +} + +ConvolutionBaseStruct *SelectDeConv(ConvParameter *conv_param) { +#ifndef _WIN32 +#ifndef ENABLE_MCU + bool param_winograd_fit = (conv_param->stride_h_ > 1 || conv_param->stride_w_ > 1) && + (conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1); + +#ifdef ENABLE_AVX + bool in_size_winograd_fit = conv_param->input_w_ * conv_param->input_h_ >= NNACL_DECONV_WINOGRAD_HW_MAX; + bool size_winograd_fit = (conv_param->kernel_w_ / conv_param->stride_w_ >= C2NUM || + conv_param->kernel_h_ / conv_param->stride_h_ >= C2NUM || conv_param->output_channel_ == 1); +#else + 
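+  /* Non-AVX builds: no minimum input-plane requirement, and the winograd deconv path is taken
+   * only when kernel/stride is strictly greater than 2 in some dimension (the AVX branch above
+   * also accepts >= 2 and the single-output-channel case). */
+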
bool in_size_winograd_fit = true; + bool size_winograd_fit = + (conv_param->kernel_w_ / conv_param->stride_w_ > C2NUM || conv_param->kernel_h_ / conv_param->stride_h_ > C2NUM); +#endif + + if (param_winograd_fit && size_winograd_fit && in_size_winograd_fit) { + ConvolutionBaseStruct *kernel = CreateDeConvWinograd(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif +#endif + + return CreateDeConv(conv_param); +} + +KernelBase *CreateConvolutionTranspose(OpParameter *param, int data_type) { + ConvParameter *conv_param = (ConvParameter *)param; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + + ConvolutionBaseStruct *conv = NULL; + if (conv_param->group_ == 1 && conv_param->input_channel_ == 1 && conv_param->output_channel_ == 1) { + conv = CreateDeConvDw(conv_param); + } else if (conv_param->group_ == 1) { + conv = SelectDeConv(conv_param); + } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) { + conv = CreateDeConvDw(conv_param); + } + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv); + ConvBaseUpdateParamInfo(&conv->compute_, conv_param); + return &conv->base_; +} + +REG_KERNEL_CREATOR(PrimType_Conv2dTransposeFusion, kNumberTypeFloat32, CreateConvolutionTranspose) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h new file mode 100644 index 0000000000000000000000000000000000000000..a7f773a61f66afcfb7f5e146ca5ec910ab3fb8a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_DECONVOLUTION_H_ +#define NNACL_KERNEL_DECONVOLUTION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/kernel/matmul_struct.h" + +typedef struct DeConvStruct { + ConvolutionBaseStruct conv_; + MatmulComputeParam matmul_; + int thread_stride_; + float *pack_input_; + float *pack_output_; + float *tmp_buffer_; + float *input_ptr_; + float *output_ptr_; +} DeConvStruct; + +int DeConvCheckvResizeValid(ConvolutionBaseStruct *conv); +KernelBase *CreateConvolutionTranspose(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_DECONVOLUTION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c new file mode 100644 index 0000000000000000000000000000000000000000..f612886b86947f12c59e23d16380c177c203619d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c @@ -0,0 +1,233 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/deconvolution_depthwise.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/kernel/deconvolution.h" + +int DeConvDwInitPackedInputOutput(DeConvDwStruct *deconv_dw) { + if (!deconv_dw->need_align_) { + return NNACL_OK; + } + ExecEnv *env = deconv_dw->conv_.base_.env_; + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + + int ic4 = UP_ROUND(compute->in_c_, compute->tile_num_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_n_, compute->in_hw_, NNACL_ERR); + int input_bhw = compute->in_n_ * compute->in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, ic4, NNACL_ERR); + int pack_input_size = input_bhw * ic4; + deconv_dw->packed_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->packed_input_); + + int oc4 = UP_ROUND(compute->out_c_, compute->tile_num_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR); + int output_bhw = compute->out_n_ * compute->out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc4, NNACL_ERR); + int pack_output_size = output_bhw * oc4; + deconv_dw->packed_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->packed_output_); + memset(deconv_dw->packed_output_, 0, pack_output_size * sizeof(float)); + + return NNACL_OK; +} + +int DeconvDwRun(void *cdata, int task_id, float l, float r) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + DeconvDwSWFp32(deconv_dw->packed_output_, deconv_dw->packed_input_, (float *)deconv_dw->conv_.packed_weight_, + (float *)deconv_dw->conv_.bias_data_, (ConvParameter *)deconv_dw->conv_.base_.param_, + &deconv_dw->sliding_, task_id); + return NNACL_OK; +} + +int DeConvDwMallocWeightBiasData(ConvolutionBaseStruct *conv) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + + int oc4 = UP_ROUND(conv->compute_.out_c_, conv->compute_.tile_num_); + if (!conv->base_.train_session_) { + int pack_weight_size = oc4 * conv->compute_.kernel_hw_; + NNACL_CHECK_MALLOC_SIZE(pack_weight_size); + deconv_dw->conv_.packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.packed_weight_); + } + + if (deconv_dw->conv_.bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(oc4 * sizeof(float)); + deconv_dw->conv_.bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, oc4 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.bias_data_); + } + memset(deconv_dw->conv_.bias_data_, 0, oc4 * sizeof(float)); + return NNACL_OK; +} + +void DeConvDwPackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = conv->base_.train_session_ ? 
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNCHWToNC4HW4Fp32(origin_weight, conv->packed_weight_, 1, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +void DeConvDwFreePackedInputOutput(DeConvDwStruct *deconv_dw) { + if (deconv_dw->need_align_) { + ExecEnv *env = deconv_dw->conv_.base_.env_; + + env->Free(env->allocator_, deconv_dw->packed_input_); + deconv_dw->packed_input_ = NULL; + env->Free(env->allocator_, deconv_dw->packed_output_); + deconv_dw->packed_output_ = NULL; + } +} + +int DeConvDwPrepare(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + deconv_dw->conv_.compute_.tile_num_ = C4NUM; + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + NNACL_CHECK_FALSE(compute->in_c_ != compute->out_c_, NNACL_DECONVOLUTION_DEPTHWISE_CHANNEL_INVALID); + NNACL_CHECK_FALSE(compute->dilation_h_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->dilation_w_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + + ConvBaseUpdateOriginWeightAndBias(&deconv_dw->conv_); + + if (self->train_session_) { + int oc4 = UP_ROUND(compute->out_c_, compute->tile_num_); + int pack_weight_size = oc4 * compute->kernel_hw_; + self->work_size_ = pack_weight_size; + } + + int ret = ConvBaseInitConvWeightBias(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.packed_weight_); + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.bias_data_); + return NNACL_OK; +} + +void DeConvDwUpdateParam(ConvolutionBaseStruct *conv) { + TensorC *input = conv->base_.in_[FIRST_INPUT]; + TensorC *output = conv->base_.out_[OUTPUT_INDEX]; + + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + conv_param->thread_num_ = conv->base_.thread_nr_; + conv_param->input_batch_ = NNACLGetBatch(output); + conv_param->input_h_ = NNACLGetHeight(output); + conv_param->input_w_ = NNACLGetWidth(output); + conv_param->input_channel_ = NNACLGetChannel(output); + conv_param->output_batch_ = NNACLGetBatch(input); + conv_param->output_h_ = NNACLGetHeight(input); + conv_param->output_w_ = NNACLGetWidth(input); + conv_param->output_channel_ = NNACLGetChannel(input); + + ConvComputeParam *compute = &conv->compute_; + compute->in_n_ = NNACLGetBatch(output); + compute->in_h_ = NNACLGetHeight(output); + compute->in_w_ = NNACLGetWidth(output); + compute->in_c_ = NNACLGetChannel(output); + compute->out_n_ = NNACLGetBatch(input); + compute->out_h_ = NNACLGetHeight(input); + compute->out_w_ = NNACLGetWidth(input); + compute->out_c_ = NNACLGetChannel(input); +} + +int DeConvDwResize(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + + (void)ConvBaseUpdateComputeInfo(&deconv_dw->conv_); + + int ret = DeConvCheckvResizeValid(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + int tile_num = deconv_dw->conv_.compute_.tile_num_; + DeConvDwUpdateParam(&deconv_dw->conv_); + (void)InitSlidingParamConvDw(&deconv_dw->sliding_, (ConvParameter *)self->param_, tile_num); + self->thread_nr_ = NNACL_MIN(self->thread_nr_, UP_DIV(deconv_dw->conv_.compute_.out_c_, tile_num)); + deconv_dw->need_align_ = deconv_dw->conv_.compute_.in_c_ % tile_num != 0; + + ret = ConvBasePrepare(&deconv_dw->conv_); + if (ret != NNACL_OK) { + 
return ret; + } + return NNACL_OK; +} + +int DeConvDwCompute(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + float *in_data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + float *out_data = (float *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + + int ret = ConvBaseRepackWeight(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = DeConvDwInitPackedInputOutput(deconv_dw); + if (ret != NNACL_OK) { + DeConvDwFreePackedInputOutput(deconv_dw); + return ret; + } + + if (deconv_dw->need_align_) { + PackNHWCToNHWC4Fp32(in_data, deconv_dw->packed_input_, compute->in_n_, compute->in_hw_, compute->in_c_); + } else { + deconv_dw->packed_input_ = in_data; + deconv_dw->packed_output_ = out_data; + memset(deconv_dw->packed_output_, 0, NNACLGetSize(out_tensor)); + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeconvDwRun, self, self->thread_nr_); + + if (deconv_dw->need_align_) { + PackNHWCXToNHWCFp32(deconv_dw->packed_output_, out_data, compute->out_n_, compute->out_hw_, compute->out_c_, + compute->tile_num_); + } + DeConvDwFreePackedInputOutput(deconv_dw); + return ret; +} + +ConvolutionBaseStruct *CreateDeConvDw(ConvParameter *param) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)malloc(sizeof(DeConvDwStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv_dw); + memset(deconv_dw, 0, sizeof(DeConvDwStruct)); + + deconv_dw->conv_.pack_weight_ = DeConvDwPackWeight; + deconv_dw->conv_.malloc_weight_bias_ = DeConvDwMallocWeightBiasData; + deconv_dw->conv_.base_.Prepare = DeConvDwPrepare; + deconv_dw->conv_.base_.Resize = DeConvDwResize; + deconv_dw->conv_.base_.Release = DefaultRelease; + deconv_dw->conv_.base_.Compute = DeConvDwCompute; + return &deconv_dw->conv_; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h new file mode 100644 index 0000000000000000000000000000000000000000..b929109e4be501de47bc96d2068008e2bdfe91de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ +#define NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct DeConvDwStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sliding_; + bool need_align_; + float *packed_input_; + float *packed_output_; +} DeConvDwStruct; + +ConvolutionBaseStruct *CreateDeConvDw(ConvParameter *param); + +#endif // NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c new file mode 100644 index 0000000000000000000000000000000000000000..727cbc1965f11f1218125fb7631c0cb703346349 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c @@ -0,0 +1,551 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _WIN32 +#ifndef ENABLE_MCU +#include "nnacl_c/kernel/deconvolution_winograd.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/deconv_winograd_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/kernel/deconvolution.h" + +void DeConvWinogradFreeResizeBuf(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + if (unit->tmp_buffer_ != NULL) { + free(unit->tmp_buffer_); + unit->tmp_buffer_ = NULL; + } + + if (unit->use_winograd_) { + if (unit->winograd_.b_buffer_ != NULL) { + free(unit->winograd_.b_buffer_); + unit->winograd_.b_buffer_ = NULL; + } + } + } + + for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { + DeConvWgABuffer *wg = ¶m->a_buffer_[i]; + if (wg->buf_init_) { + if (wg->dest_buffer_ != NULL) { + free(wg->dest_buffer_); + wg->dest_buffer_ = NULL; + } + if (wg->middle_buffer_ != NULL) { + free(wg->middle_buffer_); + wg->middle_buffer_ = NULL; + } + } + wg->buf_init_ = false; + } + + if (deconv->tile_input_ != NULL) { + free(deconv->tile_input_); + deconv->tile_input_ = NULL; + } +} + +void DeConvWinogradFreeDeconvParam(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + + if (unit->weight_ != NULL) { + free(unit->weight_); + unit->weight_ = NULL; + } + + if (unit->use_winograd_) { + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + if (unit->winograd_.BT_ != NULL) { + free(unit->winograd_.BT_); + unit->winograd_.BT_ = NULL; + } + } + } + + if (param->compute_units_ != NULL) { + free(param->compute_units_); + param->compute_units_ = NULL; + } +} + +int DeConvWinogradInitParameter(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute = &deconv->conv_.compute_; + + int 
thread_num = deconv->conv_.base_.thread_nr_; + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + + param->input_plane_ = compute->in_hw_; + param->output_plane_ = compute->out_hw_; + + param->in_tile_w_count_ = UP_DIV(compute->in_w_, WINOGRAD_DEFAULT_UNIT); + NNACL_CHECK_ZERO_RETURN_ERR(param->in_tile_w_count_); + param->in_tile_h_count_ = UP_DIV(compute->in_h_, WINOGRAD_DEFAULT_UNIT); + NNACL_CHECK_ZERO_RETURN_ERR(param->in_tile_h_count_); + param->in_tile_count_ = UP_DIV(param->in_tile_w_count_ * param->in_tile_h_count_, WINOGRAD_DEFAULT_TILE); + + deconv->conv_.base_.thread_nr_ = NNACL_MAX(1, deconv->conv_.base_.thread_nr_); + deconv->conv_.base_.thread_nr_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, param->in_tile_count_); + + deconv->thread_num_hw_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, compute->out_hw_); + NNACL_CHECK_ZERO_RETURN_ERR(deconv->thread_num_hw_); + deconv->thread_stride_hw_ = UP_DIV(compute->out_hw_, deconv->thread_num_hw_); + + int total_ic_up = WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_TILE * param->ic_up_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.base_.thread_nr_, total_ic_up, NNACL_ERR); + int size = deconv->conv_.base_.thread_nr_ * total_ic_up; + NNACL_CHECK_MALLOC_SIZE(size * sizeof(float)); + deconv->tile_input_ = (float *)malloc(size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tile_input_); + (void)memset(deconv->tile_input_, 0, size * sizeof(float)); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((WINOGRAD_DEFAULT_UNIT - 1), compute->stride_w_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((WINOGRAD_DEFAULT_UNIT - 1), compute->stride_h_, NNACL_ERR); + param->out_tile_w_ = (WINOGRAD_DEFAULT_UNIT - 1) * compute->stride_w_ + compute->kernel_w_; + param->out_tile_h_ = (WINOGRAD_DEFAULT_UNIT - 1) * compute->stride_h_ + compute->kernel_h_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + if (unit->use_winograd_) { + if (!param->a_buffer_[unit->winograd_.kh_].buf_init_) { + param->a_buffer_[unit->winograd_.kh_].buf_init_ = true; + size = unit->winograd_.kh_ * unit->winograd_.kw_ * WINOGRAD_DEFAULT_TILE * param->ic_up_; + + param->a_buffer_[unit->winograd_.kh_].middle_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->a_buffer_[unit->winograd_.kh_].middle_buffer_); + + param->a_buffer_[unit->winograd_.kh_].dest_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->a_buffer_[unit->winograd_.kh_].dest_buffer_); + } + + size = unit->winograd_.kh_ * unit->winograd_.kw_ * param->oc_up_ * WINOGRAD_DEFAULT_TILE; + unit->winograd_.b_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->winograd_.b_buffer_); + + size = unit->winograd_.kh_ * unit->winograd_.kw_ * param->oc_div_ * WINOGRAD_DEFAULT_TILE * compute->tile_num_; + unit->tmp_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->tmp_buffer_); + } else { + size = param->oc_div_ * unit->w_size_ * unit->h_size_ * WINOGRAD_DEFAULT_TILE * compute->tile_num_; + unit->tmp_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->tmp_buffer_); + } + } + + return NNACL_OK; +} + +int DeConvWgFp32Run(void *cdata, int task_id, float l, float r) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvParameter *conv_param = (ConvParameter *)deconv->conv_.base_.param_; + 
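/* Worked example (assumes the header defaults WINOGRAD_DEFAULT_UNIT == 3 and WINOGRAD_DEFAULT_TILE == 8, plus a hypothetical 16x16 input with ic_up_ == 16): in_tile_w_count_ = UP_DIV(16, 3) = 6, in_tile_h_count_ = 6, so in_tile_count_ = UP_DIV(6 * 6, 8) = 5 work items are shared by the threads below; each task_id indexes a private slab of tile_input_/tile_output_, which is why only the DeconvWgPost accumulation needs the mutex. */ +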
NNACL_CHECK_NULL_RETURN_ERR(conv_param); + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute = &deconv->conv_.compute_; + + for (int tile_index = task_id; tile_index < param->in_tile_count_; tile_index += deconv->conv_.base_.thread_nr_) { + int size = WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_TILE * param->ic_up_; + float *tile_in = deconv->tile_input_ + task_id * size; + size = param->out_tile_w_ * param->out_tile_h_ * WINOGRAD_DEFAULT_TILE * param->oc_div_ * compute->tile_num_; + float *tile_out = deconv->tile_output_ + task_id * size; + (void)memset(tile_out, 0, size * sizeof(float)); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(tile_index, WINOGRAD_DEFAULT_TILE, NNACL_ERR); + int start_index = tile_index * WINOGRAD_DEFAULT_TILE; + int cal_count = NNACL_MIN(WINOGRAD_DEFAULT_TILE, param->in_tile_w_count_ * param->in_tile_h_count_ - start_index); + + int ret = DeconvWg(deconv->nhwc_input_, tile_in, tile_out, start_index, cal_count, conv_param, param, task_id); + if (ret != NNACL_OK) { + return ret; + } + + (void)pthread_mutex_lock(&deconv->lock_); + (void)DeconvWgPost(tile_out, deconv->nc4hw4_output_, conv_param, param, cal_count, tile_index); + (void)pthread_mutex_unlock(&deconv->lock_); + } + return NNACL_OK; +} + +int DeConvWgPostFp32Run(void *cdata, int task_id, float l, float r) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvComputeParam *compute = &deconv->conv_.compute_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, deconv->thread_stride_hw_, NNACL_ERR); + int output_stride_plane = task_id * deconv->thread_stride_hw_; + int rest_plane = compute->out_hw_ - output_stride_plane; + int current_plane = MSMIN(rest_plane, deconv->thread_stride_hw_); + if (current_plane <= 0) { + return NNACL_OK; + } + + ActType act = ((ConvParameter *)deconv->conv_.base_.param_)->act_type_; + float *bias = (float *)deconv->conv_.bias_data_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_stride_plane, deconv->conv_.compute_.tile_num_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_stride_plane, deconv->conv_.compute_.out_c_, NNACL_ERR); + WinogradPostConvFuncFp32CX(deconv->nc4hw4_output_ + output_stride_plane * compute->tile_num_, + deconv->nhwc_output_ + output_stride_plane * compute->out_c_, bias, compute->out_c_, + current_plane, compute->out_hw_, act); + return NNACL_OK; +} + +int DeConvWinogradInitComputeParam(DeConvWinogradStruct *deconv) { + deconv->valid_weight_shape_ = CheckShaleValid(&deconv->conv_.base_.in_[SECOND_INPUT], Num1); + if (deconv->valid_weight_shape_ == false) { + return NNACL_OK; + } + + ConvComputeParam *compute = &deconv->conv_.compute_; + DeConvParam *param = &deconv->param_; + + param->kernel_plane_ = compute->kernel_hw_; + param->ic_div_ = UP_DIV(compute->in_c_, compute->tile_num_); + param->oc_div_ = UP_DIV(compute->out_c_, compute->tile_num_); + param->ic_up_ = param->ic_div_ * compute->tile_num_; + param->oc_up_ = param->oc_div_ * compute->tile_num_; + + param->compute_size_ = 0; + for (int si_h = 0; si_h < compute->stride_h_; si_h++) { + for (int si_w = 0; si_w < compute->stride_w_; si_w++) { + if (si_h < compute->kernel_h_ && si_w < compute->kernel_w_) { + param->compute_size_++; + } + } + } + + size_t size = (size_t)param->compute_size_ * sizeof(DeConvComputeUnit); + param->compute_units_ = (DeConvComputeUnit *)(malloc(size)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->compute_units_); + + int cur_count = 0; + for (int si_h = 0; si_h < compute->stride_h_; si_h++) { + if 
(si_h >= compute->kernel_h_) { + continue; + } + for (int si_w = 0; si_w < compute->stride_w_; si_w++) { + if (si_w >= compute->kernel_w_) { + continue; + } + + int h_size = 1 + (compute->kernel_h_ - si_h - 1) / compute->stride_h_; + int w_size = 1 + (compute->kernel_w_ - si_w - 1) / compute->stride_w_; + + DeConvComputeUnit unit; + unit.winograd_.AT_ = NULL; + unit.winograd_.BT_ = NULL; + + unit.h_start_ = si_h; + unit.w_start_ = si_w; + unit.h_size_ = h_size; + unit.w_size_ = w_size; + + unit.use_winograd_ = false; + if (h_size == w_size) { + unit.winograd_.k_ = unit.h_size_; + unit.winograd_.i_ = WINOGRAD_DEFAULT_UNIT; + unit.winograd_.o_ = WINOGRAD_DEFAULT_UNIT + unit.h_size_ - 1; + unit.winograd_.kh_ = unit.h_size_ + WINOGRAD_DEFAULT_UNIT - 1; + unit.winograd_.kw_ = unit.w_size_ + WINOGRAD_DEFAULT_UNIT - 1; + unit.use_winograd_ = unit.winograd_.kh_ < WINOGRAD_MAX_COUNT && unit.winograd_.kw_ < WINOGRAD_MAX_COUNT; + } + if (unit.use_winograd_) { + unit.winograd_.b_buffer_ = NULL; + unit.weight_ = malloc(unit.winograd_.kh_ * unit.winograd_.kw_ * param->oc_up_ * param->ic_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit.weight_); + } else { + unit.weight_ = malloc(h_size * w_size * param->ic_up_ * param->oc_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit.weight_); + } + unit.tmp_buffer_ = NULL; + param->compute_units_[cur_count] = unit; + cur_count++; + } + } + return NNACL_OK; +} + +int DeConvWinogradInitDataParam(DeConvWinogradStruct *deconv) { + TensorC *weight_tensor = deconv->conv_.base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + float *nhwc_weight = weight_tensor->data_; + if (nhwc_weight == NULL) { + deconv->conv_.is_repack_ = true; + return NNACL_OK; + } + + DeConvParam *param = &deconv->param_; + + /* unit data : weight & winograd data */ + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + int ret = PackDeConvWgDataFp32(nhwc_weight, unit, (ConvParameter *)deconv->conv_.base_.param_, param); + if (ret != NNACL_OK) { + return ret; + } + } + + /* bias */ + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + if (deconv->conv_.bias_data_ != NULL) { + env->Free(env->allocator_, deconv->conv_.bias_data_); + deconv->conv_.bias_data_ = NULL; + } + deconv->conv_.bias_data_ = env->Alloc(env->allocator_, param->oc_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->conv_.bias_data_); + (void)memset(deconv->conv_.bias_data_, 0, param->oc_up_ * sizeof(float)); + + if (deconv->conv_.base_.in_size_ == THREE_TENSOR) { + TensorC *bias_tensor = deconv->conv_.base_.in_[THIRD_INPUT]; + if (bias_tensor->shape_size_ == Num1 && NNACLGetElementNum(bias_tensor) == deconv->conv_.compute_.out_c_) { + (void)memcpy(deconv->conv_.bias_data_, bias_tensor->data_, deconv->conv_.compute_.out_c_ * sizeof(float)); + } + } + return NNACL_OK; +} + +int DeConvWinogradInitRunBuf(DeConvWinogradStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + + int size = deconv->param_.oc_up_ * deconv->conv_.compute_.out_hw_; + deconv->nc4hw4_output_ = (float *)env->Alloc(env->allocator_, size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->nc4hw4_output_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->param_.out_tile_w_, deconv->param_.out_tile_h_, NNACL_ERR); + int out_tile_hw = deconv->param_.out_tile_w_ * deconv->param_.out_tile_h_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.base_.thread_nr_, out_tile_hw, NNACL_ERR); + int total_out_tile_hw = 
deconv->conv_.base_.thread_nr_ * out_tile_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(WINOGRAD_DEFAULT_TILE, deconv->param_.oc_up_, NNACL_ERR); + int tile_oc_up = WINOGRAD_DEFAULT_TILE * deconv->param_.oc_up_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_out_tile_hw, tile_oc_up, NNACL_ERR); + size = total_out_tile_hw * tile_oc_up; + deconv->tile_output_ = (float *)env->Alloc(env->allocator_, size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tile_output_); + + return NNACL_OK; +} + +void DeConvWinogradFreeRunBuf(DeConvWinogradStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + + if (deconv->nc4hw4_output_ != NULL) { + env->Free(env->allocator_, deconv->nc4hw4_output_); + deconv->nc4hw4_output_ = NULL; + } + + if (deconv->tile_output_ != NULL) { + env->Free(env->allocator_, deconv->tile_output_); + deconv->tile_output_ = NULL; + } +} + +int InitTrainComputeInit(DeConvWinogradStruct *deconv) { + if (!deconv->valid_weight_shape_) { + int ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + if (!deconv->valid_weight_shape_ || DeConvWinogradInitParameter(deconv) != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_SHAPE; + } + } + + if (deconv->conv_.is_repack_ && DeConvWinogradInitDataParam(deconv) != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_REPACK; + } + + return NNACL_OK; +} + +int DeConvWinogradPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvComputeParam *compute = &deconv->conv_.compute_; + NNACL_CHECK_FALSE(compute->dilation_h_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->dilation_w_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->stride_h_ == Num0, NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID); + NNACL_CHECK_FALSE(compute->stride_w_ == Num0, NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID); + +#ifdef ENABLE_AVX + compute->tile_num_ = C8NUM; +#else + compute->tile_num_ = C4NUM; +#endif + + ConvBaseUpdateOriginWeightAndBias(&deconv->conv_); + + int ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + + if (deconv->valid_weight_shape_) { + ret = DeConvWinogradInitDataParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + } + + // when input data is const tensor, save data in kernel + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (NNACLIsConst(input_tensor)) { + deconv->origin_input_ = (float *)malloc(NNACLGetSize(input_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->origin_input_); + (void)memcpy(deconv->origin_input_, input_tensor->data_, NNACLGetSize(input_tensor)); + } + return NNACL_OK; +} + +int DeConvWinogradResize(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + (void)ConvBaseUpdateComputeInfo(&deconv->conv_); + + int ret = DeConvCheckvResizeValid(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + DeConvWinogradFreeResizeBuf(deconv); + + ret = ConvBasePrepare(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + if (!deconv->valid_weight_shape_) { + ret = DeConvWinogradInitComputeParam(deconv); + if 
(ret != NNACL_OK) { + return ret; + } + if (!deconv->valid_weight_shape_) { + return NNACL_OK; + } + ret = DeConvWinogradInitDataParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + } + + ret = DeConvWinogradInitParameter(deconv); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.out_hw_, deconv->conv_.compute_.out_c_, NNACL_ERR); + int output_chw = deconv->conv_.compute_.out_hw_ * deconv->conv_.compute_.out_c_; + if (output_chw <= kDeconvWinogradMaxPixel) { + self->thread_nr_ = NNACL_MIN(self->thread_nr_, Num3); + } + return NNACL_OK; +} + +int DeConvWinogradRelease(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + DeConvWinogradFreeResizeBuf(deconv); + DeConvWinogradFreeDeconvParam(deconv); + + if (deconv->origin_input_ != NULL) { + free(deconv->origin_input_); + deconv->origin_input_ = NULL; + } + return NNACL_OK; +} + +int DeConvWinogradCompute(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute_ = &deconv->conv_.compute_; + + int ret = DeConvWinogradInitRunBuf(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + ret = InitTrainComputeInit(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + float *src_in = deconv->origin_input_ != NULL ? deconv->origin_input_ : (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *src_out = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + int input_chw = compute_->in_hw_ * compute_->in_c_; + int output_chw = compute_->out_hw_ * compute_->out_c_; + for (int batch_index = 0; batch_index < compute_->in_n_; batch_index++) { + deconv->nhwc_input_ = src_in + batch_index * input_chw; + deconv->nhwc_output_ = src_out + batch_index * output_chw; + + (void)memset(deconv->nc4hw4_output_, 0, compute_->out_hw_ * param->oc_div_ * compute_->tile_num_ * sizeof(float)); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvWgFp32Run, self, self->thread_nr_); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + /* post bias activate and nhwc */ + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvWgPostFp32Run, self, self->thread_nr_); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + } + + DeConvWinogradFreeRunBuf(deconv); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateDeConvWinograd(ConvParameter *param) { + DeConvWinogradStruct *deconv_winograd = (DeConvWinogradStruct *)malloc(sizeof(DeConvWinogradStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv_winograd); + memset(deconv_winograd, 0, sizeof(DeConvWinogradStruct)); + + deconv_winograd->conv_.base_.Prepare = DeConvWinogradPrepare; + deconv_winograd->conv_.base_.Resize = DeConvWinogradResize; + deconv_winograd->conv_.base_.Release = DeConvWinogradRelease; + deconv_winograd->conv_.base_.Compute = DeConvWinogradCompute; + return &deconv_winograd->conv_; +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h new 
file mode 100644 index 0000000000000000000000000000000000000000..eabdffdfe04fb898c545db179143bd850bbb1900 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h @@ -0,0 +1,52 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ +#define NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ + +#ifndef _WIN32 +#ifndef ENABLE_MCU +#include <pthread.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" + +#define kDeconvWinogradMaxPixel 3145728 +#define WINOGRAD_DEFAULT_UNIT 3 +#define WINOGRAD_DEFAULT_TILE 8 +#define WINOGRAD_MAX_COUNT 8 + +typedef struct DeConvWinogradStruct { + ConvolutionBaseStruct conv_; + DeConvParam param_; + pthread_mutex_t lock_; + int thread_num_hw_; + int thread_stride_hw_; + float *nhwc_input_; + float *nhwc_output_; + float *tile_input_; + float *tile_output_; + float *origin_input_; + float *nc4hw4_output_; + bool valid_weight_shape_; +} DeConvWinogradStruct; + +#define NNACL_DECONV_WINOGRAD_HW_MAX 2000 + +ConvolutionBaseStruct *CreateDeConvWinograd(ConvParameter *param); +#endif +#endif +#endif // NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c new file mode 100644 index 0000000000000000000000000000000000000000..814336c233cca66f16bd8500a832dd8450214b2e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c @@ -0,0 +1,55 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/default_kernel_base.h" + +int DefaultPrepare3In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare3In2Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare1In2Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare1In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare2In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +int DefaultRelease(KernelBase *self) { return NNACL_OK; } diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h new file mode 100644 index 0000000000000000000000000000000000000000..ba666539e1f8adf4056eefc8bcd9cb35d8f5e851 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ +#define NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +int DefaultPrepare3In2Out(KernelBase *self); +int DefaultPrepare1In1Out(KernelBase *self); +int DefaultPrepare2In1Out(KernelBase *self); +int DefaultPrepare1In2Out(KernelBase *self); +int DefaultPrepare3In1Out(KernelBase *self); +int DefaultResize(KernelBase *self); +int DefaultRelease(KernelBase *self); + +#endif // NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c new file mode 100644 index 0000000000000000000000000000000000000000..3652b9fcfc10659e35b3cdaea2d63abcf5eb4ef1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c @@ -0,0 +1,80 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/depth_to_space.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/depth_to_space_parameter.h" +#include "nnacl_c/base/depth_to_space_base.h" + +int DepthToSpaceResize(KernelBase *self) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(depth_to_space); + DepthToSpaceArgs *args = &depth_to_space->args_; + + TensorC *input = self->in_[FIRST_INPUT]; + int32_t in_strides[DIMENSION_4D] = {0}; + ComputeStrides(input->shape_, in_strides, input->shape_size_); + args->in_stride_dim0_ = in_strides[Index0]; + args->in_stride_dim1_ = in_strides[Index1]; + args->in_stride_dim2_ = in_strides[Index2]; + + TensorC *output = self->out_[OUTPUT_INDEX]; + int32_t out_strides[DIMENSION_4D] = {0}; + ComputeStrides(output->shape_, out_strides, output->shape_size_); + args->out_stride_dim0_ = out_strides[Index0]; + args->out_stride_dim1_ = out_strides[Index1]; + args->out_stride_dim2_ = out_strides[Index2]; + return NNACL_OK; +} + +int DepthToSpaceCompute(KernelBase *self) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(depth_to_space); + int mode = ((DepthToSpaceParameter *)self->param_)->mode_; + + TensorC *input = self->in_[FIRST_INPUT]; + TensorC *output = self->out_[OUTPUT_INDEX]; + + if (mode == 0) { + // DCR + DepthToSpaceForNHWC(input->data_, output->data_, input->shape_, &depth_to_space->args_); + } else if (mode == 1) { + // CRD + DepthToSpaceCRDForNHWC(input->data_, output->data_, input->shape_, &depth_to_space->args_); + } else { + return NNACL_DEPTH_TO_SPACE_INVALID_MODE; + } + return NNACL_OK; +} + +KernelBase *CreateDepthToSpace(OpParameter *param, int data_type) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)malloc(sizeof(DepthToSpaceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(depth_to_space); + memset(depth_to_space, 0, sizeof(DepthToSpaceStruct)); + + depth_to_space->args_.data_type_size_ = DataTypeCSize(data_type); + depth_to_space->args_.block_size_ = ((DepthToSpaceParameter *)param)->block_size_; + depth_to_space->base_.Release = DefaultRelease; + depth_to_space->base_.Prepare = DefaultPrepare1In1Out; + depth_to_space->base_.Resize = DepthToSpaceResize; + depth_to_space->base_.Compute = DepthToSpaceCompute; + return (KernelBase *)depth_to_space; +} + +REG_KERNEL_CREATOR(PrimType_DepthToSpace, kNumberTypeFloat32, CreateDepthToSpace) +REG_KERNEL_CREATOR(PrimType_DepthToSpace, kNumberTypeFloat16, CreateDepthToSpace) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h new file mode 100644 index 0000000000000000000000000000000000000000..969ff818f27f0061cd83d8a4f089098796a7a607 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_DEPTH_TO_SPACE_H_ +#define NNACL_KERNEL_DEPTH_TO_SPACE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct DepthToSpaceArgs { + int32_t in_stride_dim0_; + int32_t in_stride_dim1_; + int32_t in_stride_dim2_; + int32_t out_stride_dim0_; + int32_t out_stride_dim1_; + int32_t out_stride_dim2_; + uint8_t data_type_size_; + int32_t block_size_; +} DepthToSpaceArgs; + +typedef struct DepthToSpaceStruct { + KernelBase base_; + DepthToSpaceArgs args_; +} DepthToSpaceStruct; + +KernelBase *CreateDepthToSpace(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_DEPTH_TO_SPACE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c new file mode 100644 index 0000000000000000000000000000000000000000..e914d553475a8cf96276fd92c75795bf8ba6c23f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c @@ -0,0 +1,86 @@ +/** + * Copyright 2022-2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/exp.h" +#include <math.h> +#include "nnacl_c/exp_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/exp_fp16.h" +#endif + +int ExpRunImpl(void *cdata, int task_id, float l, float r) { + ExpStruct *exp = (ExpStruct *)cdata; + return exp->Exp(exp->base_.in_[FIRST_INPUT]->data_, exp->base_.out_[OUTPUT_INDEX]->data_, exp, task_id); +} + +int ExpResize(struct KernelBase *self) { + ExpStruct *exp = (ExpStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(exp); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + exp->element_num_ = NNACLGetElementNum(exp->base_.in_[FIRST_INPUT]); + return NNACL_OK; +} + +int ExpPrepare(struct KernelBase *self) { + ExpStruct *exp = (ExpStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(exp); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < 1 || self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + float log_base = (param->base_ == -1) ?
1 : logf(param->base_); + float epsilon = 0.000001; + exp->in_scale_ = param->scale_ * log_base; + if (param->shift_ == 0) { + exp->out_scale_ = 1; + } else { + if (fabs(log_base - 1) < epsilon) { + exp->out_scale_ = expf(param->shift_); + } else { + exp->out_scale_ = powf(param->base_, param->shift_); + } + } + + return NNACL_OK; +} + +int ExpCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ExpRunImpl, self, self->thread_nr_); +} + +KernelBase *CreateExp(OpParameter *param, int data_type) { + ExpStruct *exp = (ExpStruct *)malloc(sizeof(ExpStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(exp); + memset(exp, 0, sizeof(ExpStruct)); + exp->base_.Prepare = ExpPrepare; + exp->base_.Resize = ExpResize; + exp->base_.Release = DefaultRelease; + exp->base_.Compute = ExpCompute; + exp->Exp = ExpFusionFp32; +#ifdef ENABLE_FP16 + if (data_type == kNumberTypeFloat16) { + exp->Exp = ExpFusionFp16; + } +#endif + return (KernelBase *)exp; +} + +REG_KERNEL_CREATOR(PrimType_ExpFusion, kNumberTypeFloat32, CreateExp) +REG_KERNEL_CREATOR(PrimType_ExpFusion, kNumberTypeFloat16, CreateExp) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h new file mode 100644 index 0000000000000000000000000000000000000000..35a5d2d8995af858f80a27bff874f5c30b1c0537 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/f16/arithmetic_compare_f16.h" +#include "nnacl_c/kernel/f16/arithmetic_f16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" + +typedef struct ArithmeticCompareF16Funcions { + int primitive_type_; + int activation_type_; + int (*compute_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); + int (*optimzie_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +} ArithmeticCompareF16Funcions; + +typedef struct ArithmeticCompareF16Struct { + ArithmeticF16Struct arithmetic_f16_; + ArithmeticCompareF16Funcions functions_; +} ArithmeticCompareF16Struct; + +void InitArithmeticCompareF16RunFunction(KernelBase *base) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base; + ArithmeticParameter *arithmetic_param = (ArithmeticParameter *)base->param_; + + ArithmeticCompareF16Funcions arithmetic_cp_fun_table_fp16[] = { + {PrimType_NotEqual, ActType_No, ElementNotEqualFp16, ElementOptNotEqualFp16}, + {PrimType_Equal, ActType_No, ElementEqualFp16, ElementOptEqualFp16}, + {PrimType_Less, ActType_No, ElementLessFp16, ElementOptLessFp16}, + {PrimType_LessEqual, ActType_No, ElementLessEqualFp16, ElementOptLessEqualFp16}, + {PrimType_Greater, ActType_No, ElementGreaterFp16, ElementOptGreaterFp16}, + {PrimType_GreaterEqual, ActType_No, ElementGreaterEqualFp16, ElementOptGreaterEqualFp16}}; + + size_t length = sizeof(arithmetic_cp_fun_table_fp16) / sizeof(ArithmeticCompareF16Funcions); + for (size_t i = 0; i < length; i++) { + if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == + arithmetic_compare_f16->arithmetic_f16_.arithmetic_.primitive_type_ && + arithmetic_cp_fun_table_fp16[i].activation_type_ == arithmetic_param->activation_type_) { + arithmetic_compare_f16->functions_ = arithmetic_cp_fun_table_fp16[i]; + return; + } + } +} + +int ArithmeticCompareF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, + int64_t size) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base; + + if (arithmetic_compare_f16->arithmetic_f16_.arithmetic_.scalar_opt_) { + bool first_scalar = arithmetic_compare_f16->arithmetic_f16_.arithmetic_.in_elements_num0_ == 1; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare_f16->functions_.optimzie_); + return arithmetic_compare_f16->functions_.optimzie_((const float16_t *)input0, (const float16_t *)input1, + (uint8_t *)output, size, first_scalar); + } + + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare_f16->functions_.compute_); + return arithmetic_compare_f16->functions_.compute_((const float16_t *)input0, (const float16_t *)input1, + (uint8_t *)output, size); +} +int ArithmeticCompareF16Compute(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + arithmetic->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_); + return ArithmeticF16Compute(self); +} + +KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = + (ArithmeticCompareF16Struct *)malloc(sizeof(ArithmeticCompareF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(arithmetic_compare_f16); + memset(arithmetic_compare_f16, 0, sizeof(ArithmeticCompareF16Struct)); + + ArithmeticStruct *arithmetic = &arithmetic_compare_f16->arithmetic_f16_.arithmetic_; +
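/* Dispatch sketch: InitArithmeticCompareF16RunFunction() picks the table entry keyed by (primitive_type_, activation_type_), e.g. PrimType_Less with ActType_No selects {ElementLessFp16, ElementOptLessFp16}; ArithmeticCompareF16DoExecute() then calls optimzie_ when scalar_opt_ is set (first_scalar = in_elements_num0_ == 1, i.e. the scalar operand comes first) and compute_ for the plain element-wise case. */ +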
arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticF16Resize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompareF16Compute; + + arithmetic->execute_ = ArithmeticCompareF16DoExecute; + arithmetic->tile_function_ = TileOneDimensionFp16; + arithmetic->init_function_ = InitArithmeticCompareF16RunFunction; + + return (KernelBase *)arithmetic_compare_f16; +} + +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..0727c77d6773c9e4825325fe4de90c25dcd9ff70 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ +#define NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..10ca660111ffe0f65dc39bef0c5f59f8685820c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c @@ -0,0 +1,195 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/f16/arithmetic_f16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/utils_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +void InitArithmeticF16RunFunction(KernelBase *base) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base; + + ArithmeticF16Funcions f16_fun_table[] = { + {PrimType_MulFusion, ActType_Relu, ElementMulReluFp16, ElementOptMulReluFp16}, + {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16}, + {PrimType_MulFusion, ActType_No, ElementMulFp16, ElementOptMulFp16}, + {PrimType_AddFusion, ActType_Relu, ElementAddReluFp16, ElementOptAddReluFp16}, + {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16}, + {PrimType_AddFusion, ActType_No, ElementAddFp16, ElementOptAddFp16}, + {PrimType_SubFusion, ActType_Relu, ElementSubReluFp16, ElementOptSubReluFp16}, + {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16}, + {PrimType_SubFusion, ActType_No, ElementSubFp16, ElementOptSubFp16}, + {PrimType_DivFusion, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimType_DivFusion, ActType_No, ElementDivFp16, ElementOptDivFp16}, + {PrimType_RealDiv, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimType_RealDiv, ActType_No, ElementDivFp16, ElementOptDivFp16}, + {PrimType_FloorMod, ActType_No, ElementFloorModFp16, ElementOptFloorModFp16}, + {PrimType_FloorDiv, ActType_No, ElementFloorDivFp16, ElementOptFloorDivFp16}, + {PrimType_LogicalAnd, ActType_No, ElementLogicalAndFp16, ElementOptLogicalAndFp16}, + {PrimType_LogicalOr, ActType_No, ElementLogicalOrFp16, ElementOptLogicalOrFp16}, + {PrimType_SquaredDifference, ActType_No, ElementSquaredDifferenceFp16, ElementOptSquaredDifferenceFp16}, + {PrimType_Maximum, ActType_No, ElementMaximumFp16, ElementOptMaximumFp16}, + {PrimType_Minimum, ActType_No, ElementMinimumFp16, ElementOptMinimumFp16}}; + + size_t length = sizeof(f16_fun_table) / sizeof(ArithmeticF16Funcions); + for (size_t i = 0; i < length; i++) { + if (f16_fun_table[i].primitive_type_ == arithmetic_f16->arithmetic_.primitive_type_ && + f16_fun_table[i].activation_type_ == + ((ArithmeticParameter *)(arithmetic_f16->arithmetic_.base_.param_))->activation_type_) { + arithmetic_f16->functions_ = f16_fun_table[i]; + return; + } + } +} + +int ArithmeticF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base; + + if (arithmetic_f16->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.optimzie_); + return arithmetic_f16->functions_.optimzie_((const float16_t *)input0, (const float16_t *)input1, + (float16_t *)output, size, + arithmetic_f16->arithmetic_.in_elements_num0_ == 1); + } + + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.compute_); + return arithmetic_f16->functions_.compute_((const float16_t *)input0, (const float16_t *)input1, (float16_t *)output, + size); +} + +int ArithmeticF16Resize(KernelBase *self) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16); + ArithmeticStruct *arithmetic = 
(ArithmeticStruct *)self; + + arithmetic->in_data_size_ = sizeof(float16_t); + arithmetic->out_data_size_ = sizeof(float16_t); + if (arithmetic->in_elements_num1_ != 1 && arithmetic->in_elements_num0_ != 1) { + if (arithmetic->a_matrix_.is_const_ && self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat32) { + TensorC *t = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(t->data_); + void *f32_data = t->data_; + t->data_type_ = kNumberTypeFloat16; + t->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), NNACLGetElementNum(t)); + self->env_->Free(self->env_->allocator_, f32_data); + } + if (arithmetic->b_matrix_.is_const_ && self->in_[SECOND_INPUT]->data_type_ == kNumberTypeFloat32) { + TensorC *t = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(t->data_); + void *f32_data = t->data_; + t->data_type_ = kNumberTypeFloat16; + t->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]->data_); + Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), NNACLGetElementNum(t)); + self->env_->Free(self->env_->allocator_, f32_data); + } + } + return ArithmeticResize(self); +} + +void FreeArithmeticF16Buffers(ArithmeticF16Struct *arithmetic_f16) { + for (int i = 0; i < THREE_TENSOR; i++) { + if (arithmetic_f16->tmp_buffer_[i] != NULL) { + arithmetic_f16->arithmetic_.base_.env_->Free(arithmetic_f16->arithmetic_.base_.env_->allocator_, + arithmetic_f16->tmp_buffer_[i]); + arithmetic_f16->tmp_buffer_[i] = NULL; + } + } +} + +int ArithmeticF16Compute(KernelBase *self) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16); + + int in0_data_type = self->in_[FIRST_INPUT]->data_type_; + int in1_data_type = self->in_[SECOND_INPUT]->data_type_; + int out_data_type = self->out_[OUTPUT_INDEX]->data_type_; + + NNACL_CHECK_FALSE(in0_data_type != kNumberTypeFloat32 && in0_data_type != kNumberTypeFloat16, + NNACL_UNSUPPORTED_DATA_TYPE); + NNACL_CHECK_FALSE(in1_data_type != kNumberTypeFloat16 && in1_data_type != kNumberTypeFloat32, + NNACL_UNSUPPORTED_DATA_TYPE); + + if (!arithmetic_f16->arithmetic_.a_matrix_.is_valid_) { + arithmetic_f16->arithmetic_.a_matrix_.data_ = GetOrAllocFp16Data(self->in_[FIRST_INPUT], self->env_, true); + arithmetic_f16->tmp_buffer_[FIRST_INPUT] = + in0_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.a_matrix_.data_; + } + + if (!arithmetic_f16->arithmetic_.b_matrix_.is_valid_) { + arithmetic_f16->arithmetic_.b_matrix_.data_ = GetOrAllocFp16Data(self->in_[SECOND_INPUT], self->env_, true); + arithmetic_f16->tmp_buffer_[SECOND_INPUT] = + in1_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.b_matrix_.data_; + } + + arithmetic_f16->arithmetic_.c_matrix_.data_ = GetOrAllocFp16Data(self->out_[OUTPUT_INDEX], self->env_, false); + arithmetic_f16->tmp_buffer_[THIRD_INPUT] = + out_data_type == kNumberTypeFloat16 ?
NULL : arithmetic_f16->arithmetic_.c_matrix_.data_; + + int ret = ArithmeticCompute(self); + if (ret == NNACL_OK && out_data_type == kNumberTypeFloat32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->arithmetic_.c_matrix_.data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + Float16ToFloat32((float16_t *)(arithmetic_f16->arithmetic_.c_matrix_.data_), + (float *)(self->out_[OUTPUT_INDEX]->data_), NNACLGetElementNum(self->out_[OUTPUT_INDEX])); + } + + FreeArithmeticF16Buffers(arithmetic_f16); + return ret; +} + +KernelBase *CreateArithmeticF16(OpParameter *param, int data_type) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)malloc(sizeof(ArithmeticF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(arithmetic_f16); + memset(arithmetic_f16, 0, sizeof(ArithmeticF16Struct)); + + ArithmeticStruct *arithmetic = &arithmetic_f16->arithmetic_; + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticF16Resize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticF16Compute; + + arithmetic->execute_ = ArithmeticF16DoExecute; + arithmetic->tile_function_ = TileOneDimensionFp16; + arithmetic->init_function_ = InitArithmeticF16RunFunction; + + return (KernelBase *)arithmetic_f16; +} + +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat16, CreateArithmeticF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..9e6f8fcc4c5b7cc04684a924eb38feeec75c3e07 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_ARITHMETIC_F16_H_ +#define NNACL_KERNEL_F16_ARITHMETIC_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/arithmetic.h" + +typedef struct ArithmeticF16Funcions { + int primitive_type_; + int activation_type_; + int (*compute_)(const float16_t *in1, const float16_t *in2, float16_t *out, int ele); + int (*optimzie_)(const float16_t *in1, const float16_t *in2, float16_t *out, int ele, bool first_scalar); +} ArithmeticF16Funcions; + +typedef struct ArithmeticF16Struct { + ArithmeticStruct arithmetic_; + ArithmeticF16Funcions functions_; + void *tmp_buffer_[THREE_TENSOR]; /* in_size + out_size */ +} ArithmeticF16Struct; + +KernelBase *CreateArithmeticF16(OpParameter *param, int data_type); +int ArithmeticF16Resize(KernelBase *self); +int ArithmeticF16Compute(KernelBase *self); + +#endif // NNACL_KERNEL_F16_ARITHMETIC_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..ab6f67c18e5c4b5434471910fe494df0cee94e01 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c @@ -0,0 +1,132 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/f16/concat_f16.h" +#include "nnacl_c/kernel/concat.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/utils_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +typedef struct ConcatF16Struct { + ConcatStruct concat_; + void **tmp_buffer_; /* in_size + out_size */ +} ConcatF16Struct; + +int ConcatEnsureFp16InputsAndOutput(ConcatF16Struct *concat_f16) { + ConcatStruct *concat = &concat_f16->concat_; + + int tmp_buffer_size = (concat->base_.in_size_ + concat->base_.out_size_) * sizeof(float16_t *); + concat_f16->tmp_buffer_ = concat->base_.env_->Alloc(concat->base_.env_->allocator_, tmp_buffer_size); + NNACL_CHECK_NULL_RETURN_ERR(concat_f16->tmp_buffer_); + memset(concat_f16->tmp_buffer_, 0, tmp_buffer_size); + + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + if (!concat->is_with_data_[i]) { + continue; + } + + concat->inputs_ptr_[i] = GetOrAllocFp16Data(concat->base_.in_[i], concat->base_.env_, true); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_[i]); + if (concat->base_.in_[i]->data_type_ == kNumberTypeFloat32 || + concat->base_.in_[i]->data_type_ == kNumberTypeFloat) { + concat_f16->tmp_buffer_[i] = concat->inputs_ptr_[i]; + } + } + + concat->output_ = GetOrAllocFp16Data(concat->base_.out_[OUTPUT_INDEX], concat->base_.env_, false); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->output_); + if (concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32 || + concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat) { + concat_f16->tmp_buffer_[concat->base_.in_size_] = concat->output_; + } + return NNACL_OK; +} + +int ConcatFp16Run(void *cdata, int task_id, float l, float r) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(concat_f16); + ConcatStruct *concat = &concat_f16->concat_; + return DoConcat(concat, task_id); +} + +void ConcatF16FreeTmpBuffer(ConcatF16Struct *concat_f16) { + if (concat_f16->tmp_buffer_ != NULL) { + /* free tmp_buffer_[i] */ + for (int i = 0; i < (concat_f16->concat_.base_.in_size_ + concat_f16->concat_.base_.out_size_); i++) { + if (concat_f16->tmp_buffer_[i] != NULL) { + concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_[i]); + } + concat_f16->tmp_buffer_[i] = NULL; + } + + /* free tmp_buffer_ */ + concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_); + concat_f16->tmp_buffer_ = NULL; + } +} + +int ConcatF16Compute(KernelBase *self) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat_f16); + ConcatStruct *concat = &concat_f16->concat_; + + if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) { + return NNACL_OK; + } + + int ret = ConcatEnsureFp16InputsAndOutput(concat_f16); + if (ret != NNACL_OK) { + ConcatF16FreeTmpBuffer(concat_f16); + return ret; + } + + NNACL_CHECK_NULL_RETURN_ERR(concat->output_); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatFp16Run, self, self->thread_nr_); + if (ret == NNACL_OK) { + TensorC *output_tensor = concat->base_.out_[FIRST_INPUT]; + if (output_tensor->data_type_ == kNumberTypeFloat32 || output_tensor->data_type_ == kNumberTypeFloat) { + float *output = concat->base_.out_[FIRST_INPUT]->data_; + if (output == NULL) { + ret = NNACL_CONCAT_F16_OUTPUT_DATA_INVALID; + } else { + Float16ToFloat32((float16_t *)concat->output_, output, NNACLGetElementNum(output_tensor)); + } + } + } + + ConcatF16FreeTmpBuffer(concat_f16); + return ret; 
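+ /* Buffer-lifetime note: tmp_buffer_ has in_size_ + out_size_ slots, but only slots whose tensor arrived as fp32 point at a staging fp16 copy allocated by GetOrAllocFp16Data(); native fp16 tensors are used in place and their slots stay NULL, so ConcatF16FreeTmpBuffer() releases exactly the conversions made for this run. */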
+} + +KernelBase *CreateConcatF16(OpParameter *param, int data_type) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)malloc(sizeof(ConcatF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(concat_f16); + memset(concat_f16, 0, sizeof(ConcatF16Struct)); + + ConcatStruct *concat = &concat_f16->concat_; + concat->data_type_ = kNumberTypeFloat16; + concat->inner_sizes_ = NULL; + concat->inputs_ptr_ = NULL; + concat->is_with_data_ = NULL; + concat->base_.Prepare = ConcatPepare; + concat->base_.Resize = ConcatResize; + concat->base_.Release = ConcatRelease; + concat->base_.Compute = ConcatF16Compute; + concat_f16->tmp_buffer_ = NULL; + return (KernelBase *)concat; +} + +REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat16, CreateConcatF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..7a6eb5af08618167a021657473742b63d0d913db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ +#define MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateConcatF16(OpParameter *param, int data_type); + +#endif // MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..0f1cb06e3eaaf28b468aa445d56df4faed0e8a8b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c @@ -0,0 +1,118 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/reduce_f16.h" +#include "nnacl_c/fp16/reduce_fp16.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +typedef struct ReduceF16Compute { + int type_; + int (*f16_reducer_)(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data, + float16_t *dst_data, const int tid, const int thread_num); +} ReduceF16Compute; + +typedef struct ReduceF16Struct { + ReduceStruct reduce_; + ReduceF16Compute compute_; +} ReduceF16Struct; + +int CallReduceF16Unit(KernelBase *base, int task_id) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)base; + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->reduce_.src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->reduce_.src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->compute_.f16_reducer_); + + return reduce_f16->compute_.f16_reducer_(reduce_f16->reduce_.outer_size_, reduce_f16->reduce_.inner_size_, + reduce_f16->reduce_.axis_size_, + (const float16_t *)reduce_f16->reduce_.src_data_, + (float16_t *)reduce_f16->reduce_.dst_data_, task_id, base->thread_nr_); +} + +void InitialReduceF16KernelList(KernelBase *base) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)base; + ReduceParameter *param = (ReduceParameter *)(base->param_); + + ReduceF16Compute func_list[] = {{Reduce_Sum, ReduceSumFp16}, {Reduce_Mean, ReduceMeanFp16}, + {Reduce_Max, ReduceMaxFp16}, {Reduce_Min, ReduceMinFp16}, + {Reduce_Prod, ReduceProdFp16}, {Reduce_SumSquare, ReduceSumFp16}, + {Reduce_ASum, ReduceSumFp16}, {Reduce_L2, ReduceL2NormFp16}}; + + size_t list_len = sizeof(func_list) / sizeof(ReduceF16Compute); + for (size_t i = 0; i < list_len; ++i) { + if (param->mode_ == func_list[i].type_) { + reduce_f16->compute_ = func_list[i]; + return; + } + } +} + +void HandleReduceF16ASumAndSumSquare(KernelBase *base) { + TensorC *in_tensor = base->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(in_tensor); + float16_t *data = (float16_t *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(data); + + int num = NNACLGetElementNum(in_tensor); + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_ASum) { + for (int i = 0; i < num; ++i) { + if (data[i] < 0.0f) { + data[i] = 0.0f - data[i]; + } + } + } + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_SumSquare) { + for (int i = 0; i < num; ++i) { + data[i] = data[i] * data[i]; + } + return; + } +} + +int CalculateReduceF16CoeffOutput(KernelBase *base) { + TensorC *out_tensor = base->out_[OUTPUT_INDEX]; + int num = NNACLGetElementNum(out_tensor); + + float16_t *out_data = (float16_t *)out_tensor->data_; + for (int i = 0; i < num; ++i) { + out_data[i] *= ((ReduceParameter *)base->param_)->coeff; + } + return NNACL_OK; +} + +KernelBase *CreateReduceF16(OpParameter *param, int data_type) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)malloc(sizeof(ReduceF16Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reduce_f16); + memset(reduce_f16, 0, sizeof(ReduceF16Struct)); + + ReduceStruct *reduce = &reduce_f16->reduce_; + reduce->data_type_ = data_type; + reduce->base_.Release = DefaultRelease; + reduce->base_.Prepare = ReducePrepare; + reduce->base_.Resize = ReduceResize; + reduce->base_.Compute = ReduceCompute; + + reduce->handle_sum_square_ = HandleReduceF16ASumAndSumSquare; + reduce->calculate_coeff_ = CalculateReduceF16CoeffOutput; + reduce->init_kernel_list_ = InitialReduceF16KernelList; + reduce->call_uint_ = CallReduceF16Unit; + + return (KernelBase *)reduce_f16; +} + +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeFloat16, 
CreateReduceF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..df990afdd7c53a33e86f22791b8965a1c2dc6970 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_REDUCE_F16_H_ +#define NNACL_KERNEL_F16_REDUCE_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/reduce.h" + +KernelBase *CreateReduceF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_REDUCE_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..cc748a0ea83f04056c18dbefe74c909d6b60b25d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c @@ -0,0 +1,96 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/stack_f16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +void *StackF16InitBuffer(KernelBase *base, TensorC *t, bool init) { + if (init == false) { + return t->data_; + } + + int ele_num = NNACLGetElementNum(t); + void *f16_buffer = base->env_->Alloc(base->env_->allocator_, ele_num * sizeof(float16_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(f16_buffer); + Float32ToFloat16(t->data_, f16_buffer, ele_num); + return f16_buffer; +} + +int StackF16InitMallocFlags(StackF16Struct *stack_f16) { + KernelBase *base = (KernelBase *)stack_f16; + stack_f16->init_ = base->env_->Alloc(base->env_->allocator_, (base->in_size_ + base->out_size_) * sizeof(bool)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->init_); + + for (size_t i = 0; i < base->in_size_; ++i) { + stack_f16->init_[i] = base->in_[i]->data_type_ == kNumberTypeFloat32; + stack_f16->stack_.buffers_[i] = StackF16InitBuffer(base, base->in_[i], stack_f16->init_[i]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[i]); + } + stack_f16->init_[base->in_size_] = base->out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32; + stack_f16->stack_.buffers_[base->in_size_] = + StackF16InitBuffer(base, base->out_[OUTPUT_INDEX], stack_f16->init_[base->in_size_]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[base->in_size_]); + return NNACL_OK; +} + +void StackF16FreeBuffer(StackF16Struct *stack_f16) { + if (stack_f16->init_[stack_f16->stack_.base_.in_size_]) { + /* output transfer */ + Float16ToFloat32((float16_t *)stack_f16->stack_.buffers_[stack_f16->stack_.base_.in_size_], + (float *)stack_f16->stack_.base_.out_[OUTPUT_INDEX]->data_, + NNACLGetElementNum(stack_f16->stack_.base_.out_[OUTPUT_INDEX])); + } + + for (size_t i = 0; i < (stack_f16->stack_.base_.in_size_ + stack_f16->stack_.base_.out_size_); ++i) { + if (stack_f16->init_[i]) { + stack_f16->stack_.base_.env_->Free(stack_f16->stack_.base_.env_->allocator_, stack_f16->stack_.buffers_[i]); + } + stack_f16->stack_.buffers_[i] = NULL; + } + + stack_f16->stack_.base_.env_->Free(stack_f16->stack_.base_.env_->allocator_, stack_f16->init_); + stack_f16->init_ = NULL; +} + +int StackF16Compute(KernelBase *self) { + StackF16Struct *stack_f16 = (StackF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack_f16); + + int ret = StackF16InitMallocFlags(stack_f16); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_); + StackF16FreeBuffer(stack_f16); + return ret; +} + +KernelBase *CreateStackF16(OpParameter *param, int data_type) { + StackF16Struct *stack_f16 = (StackF16Struct *)malloc(sizeof(StackF16Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack_f16); + StackStruct *stack = &stack_f16->stack_; + stack->buffers_ = NULL; + stack->data_type_ = data_type; + stack->base_.Release = StackRelease; + stack->base_.Prepare = StackPrepare; + stack->base_.Resize = StackResize; + stack->base_.Compute = StackF16Compute; + return (KernelBase *)stack; +} + +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat16, CreateStackF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..640f04fbbfa744843b5462c9bf2f85302d8cdb06 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., 
Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_STACK_F16_H_ +#define NNACL_KERNEL_F16_STACK_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/stack.h" + +typedef struct StackF16Struct { + StackStruct stack_; + bool *init_; +} StackF16Struct; + +KernelBase *CreateStackF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_STACK_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c new file mode 100644 index 0000000000000000000000000000000000000000..098704ab10f82cbf89a1ff07bc23090b6abedb83 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c @@ -0,0 +1,102 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/fill.h" +#include "nnacl_c/fill_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/fill_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/fill_fp16.h" +#endif + +int FillResize(struct KernelBase *self) { + FillStruct *fill = (FillStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fill); + fill->base_.thread_nr_ = fill->base_.UpdateThread( + TC_PTYPE(PrimType_Fill), 0, 1, NNACLGetSize(fill->base_.out_[OUTPUT_INDEX]), fill->base_.thread_nr_); + + NNACL_CHECK_NULL_RETURN_ERR(fill->base_.out_[OUTPUT_INDEX]); + fill->data_size_ = (int)NNACLGetElementNum(fill->base_.out_[OUTPUT_INDEX]); + fill->thread_sz_count_ = MSMIN(fill->base_.thread_nr_, fill->data_size_); + if (fill->thread_sz_count_ != 0) { + fill->thread_sz_stride_ = UP_DIV(fill->data_size_, fill->thread_sz_count_); + } + return NNACL_OK; +} + +int FillImpl(void *cdata, int task_id, float l, float r) { + FillStruct *fill = (FillStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(fill); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, fill->thread_sz_stride_, NNACL_ERR); + int size = MSMIN(fill->thread_sz_stride_, fill->data_size_ - task_id * fill->thread_sz_stride_); + NNACL_CHECK_FALSE(size <= 0, NNACL_OK); + int offset = task_id * fill->thread_sz_stride_; + int ret = NNACL_OK; + switch (fill->base_.in_[FIRST_INPUT]->data_type_) { +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + ret = FillFp16((float16_t *)fill->out_ptr_ + offset, size, ((float16_t *)fill->src_data_)[FIRST_INPUT]); + break; +#endif + case kNumberTypeFloat32: + ret = FillFp32((float *)fill->out_ptr_ + offset, size, ((float *)fill->src_data_)[FIRST_INPUT]); + break; + case kNumberTypeInt32: + ret = FillInt32((int *)fill->out_ptr_ + offset, size, ((int *)fill->src_data_)[FIRST_INPUT]); + break; + case kNumberTypeBool: + ret = FillBool((bool *)fill->out_ptr_ + offset, size, ((bool *)fill->src_data_)[FIRST_INPUT]); + break; + default: + return NNACL_FILL_DATA_TYPE_INVALID; + } + return ret; +} + +int FillCompute(struct KernelBase *self) { + FillStruct *fill = (FillStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fill); + + fill->src_data_ = (void *)fill->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(fill->src_data_); + fill->out_ptr_ = (void *)fill->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(fill->out_ptr_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, FillImpl, fill, fill->base_.thread_nr_); +} + +KernelBase *CreateFill(OpParameter *param, int data_type) { + FillStruct *fill = (FillStruct *)malloc(sizeof(FillStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fill); + fill->base_.Prepare = DefaultPrepare2In1Out; + fill->base_.Resize = FillResize; + fill->base_.Release = DefaultRelease; + fill->base_.Compute = FillCompute; + return (KernelBase *)fill; +} + +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeBool, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeInt32, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat32, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat16, CreateFill); + +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeBool, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeInt32, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat32, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat16, CreateFill); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h new file mode 100644 index 0000000000000000000000000000000000000000..3cb441367d8e0876df38fe939510a25d19100c71 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_FILL_H_ +#define NNACL_KERNEL_FILL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct FillStruct { + KernelBase base_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + void *src_data_; + void *out_ptr_; + int thread_count_; +} FillStruct; + +KernelBase *CreateFill(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FILL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c new file mode 100644 index 0000000000000000000000000000000000000000..215aa8e2565dbc926396450d3a43209067c572e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c @@ -0,0 +1,81 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/fullconnection.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_create.h" + +int FullConnectionPrepare(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR); + + if (matmul->a_const_ || matmul->infer_shape_) { + int *a_shape = self->in_[FIRST_INPUT]->shape_; + matmul->compute_.row_ = a_shape[0]; + matmul->compute_.deep_ = a_shape[1]; + } + + if (matmul->b_const_ || matmul->infer_shape_) { + int *b_shape = self->in_[SECOND_INPUT]->shape_; + matmul->compute_.col_ = b_shape[0]; + matmul->compute_.deep_ = b_shape[1]; + } + + matmul->batch_ = 1; + matmul->a_batch_ = 1; + matmul->b_batch_ = 1; + + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + param->a_transpose_ = false; + param->b_transpose_ = true; + + int ret = MatmulBaseMallocBatchOffset(matmul); + if (ret != NNACL_OK) { + return ret; + } + + return MatmulBasePrepare(self); +} + +int FullConnectionResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + NNACL_CHECK_TRUE_RET(self->out_[0]->shape_size_ > 0, NNACL_ERR); + + int row = 1; + for (size_t i = 0; i < self->out_[0]->shape_size_ - 1; ++i) { + row *= (self->out_[OUTPUT_INDEX]->shape_)[i]; + } + matmul->compute_.row_ = row; + matmul->compute_.col_ = (self->out_[OUTPUT_INDEX]->shape_)[self->out_[0]->shape_size_ - 1]; + matmul->compute_.deep_ = self->in_[SECOND_INPUT]->shape_[SECOND_INPUT]; + + return MatmulBaseResize(self); +} + +KernelBase *CreateFullconnection(OpParameter *param, int data_type) { + KernelBase *kernel = NULL; + if (data_type == kNumberTypeFloat32) { + kernel = CreateMatmulKernel(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + kernel->Prepare = FullConnectionPrepare; + kernel->Resize = FullConnectionResize; + } + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_FullConnection, kNumberTypeFloat32, CreateFullconnection); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h new file mode 100644 index 0000000000000000000000000000000000000000..a54116d959b6e55d92daae97139b888b5976efcf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_FULLCONNECTION_H_ +#define NNACL_KERNEL_FULLCONNECTION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateFullconnection(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FULLCONNECTION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..5e49da11b0d2ec8c762d6ca6d7b124e2cd186990 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c @@ -0,0 +1,327 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/fused_batch_norm.h" +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/batchnorm_parameter.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32/scale_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/scale_fp16.h" +#include "nnacl_c/fp16/batchnorm_fp16.h" +#endif + +int FusedBatchNormInitScaleParam(FusedBatchNormStruct *fused_batch_norm) { + ScaleStruct *scale = &fused_batch_norm->scale_param_; + scale->base_.thread_nr_ = fused_batch_norm->bn_.base_.thread_nr_; + + scale->axis_ = kNHWC_C; + TensorC *in_tensor = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]; + if (in_tensor->shape_size_ != DIMENSION_4D) { + return NNACL_FUSED_BATCH_NORM_NO_CHANGE; + } + + scale->outer_size_ = 1; + for (int i = 0; i < scale->axis_; i++) { + scale->outer_size_ *= in_tensor->shape_[i]; + } + scale->axis_size_ = in_tensor->shape_[Index3]; + scale->inner_size_ = 1; + return NNACL_OK; +} + +void FusedBatchNormCalculateScaleF32(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { + float *fp32_scale_origin = (float *)scale_data; + float *fp32_var_origin = (float *)var_data; + float *fp32_bias_origin = (float *)bias_data; + float *fp32_mean_origin = (float *)mean_data; + + float *fp32_scale = (float *)fbn->scale_; + for (int i = 0; i < kernel_num; i++) { + fp32_scale[i] = fp32_scale_origin[i] / sqrtf(fp32_var_origin[i] + eps); + } + + float *fp32_offset = (float *)fbn->offset_; + for (int i = 0; i < kernel_num; i++) { + fp32_offset[i] = fp32_bias_origin[i] - fp32_mean_origin[i] * fp32_scale[i]; + } +} + +void FusedBatchNormCalculateScaleF16(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { +#ifdef ENABLE_FP16 + float16_t *fp16_scale_origin = (float16_t *)scale_data; + float16_t *fp16_var_origin = (float16_t *)var_data; + float16_t *fp16_bias_origin = (float16_t *)bias_data; + float16_t *fp16_mean_origin = (float16_t *)mean_data; + + float16_t *fp16_scale = (float16_t *)fbn->scale_; + for (int i = 0; i < kernel_num; 
i++) { + fp16_scale[i] = fp16_scale_origin[i] / sqrtf(fp16_var_origin[i] + eps); + } + + float16_t *fp16_offset = (float16_t *)fbn->offset_; + for (int i = 0; i < kernel_num; i++) { + fp16_offset[i] = fp16_bias_origin[i] - fp16_mean_origin[i] * fp16_scale[i]; + } +#endif +} + +void FusedBatchNormRunFp16(FusedBatchNormStruct *fused_batch_norm, int task_id) { +#ifdef ENABLE_FP16 + void *in_data = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_; + void *out_data = fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_; + + if (fused_batch_norm->is_scale_) { + DoScaleFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)fused_batch_norm->scale_, + (float16_t *)fused_batch_norm->offset_, task_id, &fused_batch_norm->scale_param_); + } else { + FusedBatchNormFp16((float16_t *)in_data, (float16_t *)fused_batch_norm->scale_, + (float16_t *)fused_batch_norm->offset_, (float16_t *)fused_batch_norm->bn_.mean_, + (float16_t *)fused_batch_norm->bn_.variance_, &fused_batch_norm->bn_, task_id, + fused_batch_norm->bn_.base_.thread_nr_, (float16_t *)out_data); + } +#endif +} + +int FusedBatchNormBatchnorm2Scale(FusedBatchNormStruct *fused_batch_norm, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { + int ret = FusedBatchNormInitScaleParam(fused_batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + ExecEnv *env = fused_batch_norm->bn_.base_.env_; + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + fused_batch_norm->scale_ = env->Alloc(env->allocator_, NNACLGetSize(scale_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_); + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + fused_batch_norm->offset_ = env->Alloc(env->allocator_, NNACLGetSize(offset_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_); + + // new scale: scale / sqrt(variance + eps) + // new bias: bias - scale * mean / sqrt(variance + eps) + if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) { + FusedBatchNormCalculateScaleF16(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num); + } else { + FusedBatchNormCalculateScaleF32(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num); + } + + fused_batch_norm->is_scale_ = true; + return NNACL_OK; +} + +int FusedBatchNormInitConstTensor(FusedBatchNormStruct *fused_batch_norm) { + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT]; + TensorC *variance_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT]; + + if (!fused_batch_norm->bn_.base_.train_session_) { + int ret = FusedBatchNormBatchnorm2Scale( + fused_batch_norm, (float *)scale_tensor->data_, (float *)offset_tensor->data_, (float *)mean_tensor->data_, + (float *)variance_tensor->data_, fused_batch_norm->bn_.epsilon_, NNACLGetElementNum(scale_tensor)); + if (ret == NNACL_OK) { + return NNACL_OK; + } else { + fused_batch_norm->bn_.base_.Release(&fused_batch_norm->bn_.base_); + if (ret != NNACL_FUSED_BATCH_NORM_NO_CHANGE) { + return NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED; + } + } + } + + ExecEnv *env = fused_batch_norm->bn_.base_.env_; + fused_batch_norm->scale_ = env->Alloc(env->allocator_, NNACLGetSize(scale_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_); + (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, 
NNACLGetSize(scale_tensor)); + fused_batch_norm->offset_ = env->Alloc(env->allocator_, NNACLGetSize(offset_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_); + (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + fused_batch_norm->bn_.mean_ = env->Alloc(env->allocator_, NNACLGetSize(mean_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.mean_); + (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, NNACLGetSize(mean_tensor)); + fused_batch_norm->bn_.variance_ = env->Alloc(env->allocator_, NNACLGetSize(variance_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.variance_); + (void)memcpy(fused_batch_norm->bn_.variance_, variance_tensor->data_, NNACLGetSize(variance_tensor)); + return NNACL_OK; +} + +void FusedBatchNormRunFp32(FusedBatchNormStruct *fused_batch_norm, int task_id) { + void *in_data = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_; + void *out_data = fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_; + + if (fused_batch_norm->is_scale_) { + DoScale((float *)in_data, (float *)out_data, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, + task_id, &fused_batch_norm->scale_param_); + } else { + FusedBatchNormFp32((float *)in_data, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, + (float *)fused_batch_norm->bn_.mean_, (float *)fused_batch_norm->bn_.variance_, + &fused_batch_norm->bn_, task_id, fused_batch_norm->bn_.base_.thread_nr_, (float *)out_data); + } +} + +int FusedBatchNormRun(void *cdata, int task_id, float l, float r) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) { + FusedBatchNormRunFp16(fused_batch_norm, task_id); + } else if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat32) { + FusedBatchNormRunFp32(fused_batch_norm, task_id); + } + return NNACL_OK; +} + +int FusedBatchNormTrainComputeInit(FusedBatchNormStruct *fbn) { + if (fbn->bn_.base_.out_size_ < Num5) { + return NNACL_OK; + } + + TensorC *out_scale = fbn->bn_.base_.out_[SECOND_INPUT]; + TensorC *out_offset = fbn->bn_.base_.out_[THIRD_INPUT]; + TensorC *out_mean = fbn->bn_.base_.out_[FOURTH_INPUT]; + TensorC *out_var = fbn->bn_.base_.out_[FIFTH_INPUT]; + + void *current_mean = fbn->bn_.mean_; + void *current_var = fbn->bn_.variance_; + + bool schema_trained = ((BatchNormParameter *)fbn->bn_.base_.param_)->is_training_; + if (fbn->train_mode_ && schema_trained && fbn->bn_.base_.in_size_ >= Num5) { + TensorC *in_tensor = fbn->bn_.base_.in_[FIRST_INPUT]; + TensorC *scale_tensor = fbn->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fbn->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fbn->bn_.base_.in_[FOURTH_INPUT]; + TensorC *var_tensor = fbn->bn_.base_.in_[FIFTH_INPUT]; + if (in_tensor->data_ == NULL || scale_tensor->data_ == NULL || offset_tensor->data_ == NULL || + mean_tensor->data_ == NULL || var_tensor->data_ == NULL) { + return NNACL_FUSED_BATCH_TRAIN_DATA_INVALID; + } + + memset(current_mean, 0, NNACLGetSize(mean_tensor)); + memset(current_var, 0, NNACLGetSize(var_tensor)); + + bool isBatch2d = true; + if (fbn->bn_.base_.in_[FIRST_INPUT]->shape_size_ == Num2) isBatch2d = false; + + if (fbn->bn_.data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + FusedBatchNormFp16MeanVar((float16_t *)in_tensor->data_, (float16_t *)current_mean, current_var, &fbn->bn_, + (float16_t *)mean_tensor->data_, 
(float16_t *)var_tensor->data_); +#endif + } else { + FusedBatchNormFp32MeanVar((float *)in_tensor->data_, (float *)current_mean, current_var, &fbn->bn_, + (float *)mean_tensor->data_, (float *)var_tensor->data_, isBatch2d); + } + + (void)memcpy(out_scale->data_, scale_tensor->data_, NNACLGetSize(out_scale)); + (void)memcpy(out_offset->data_, offset_tensor->data_, NNACLGetSize(out_offset)); + (void)memcpy(out_mean->data_, current_mean, NNACLGetSize(out_mean)); + (void)memcpy(out_var->data_, current_var, NNACLGetSize(out_var)); + + // Copy to local variables + (void)memcpy(fbn->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); + (void)memcpy(fbn->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + + fbn->trained_ = true; // trained at least once + return NNACL_OK; + } + + if (fbn->bn_.base_.train_session_) { + (void)memcpy(out_scale->data_, fbn->scale_, NNACLGetSize(out_scale)); + (void)memcpy(out_offset->data_, fbn->offset_, NNACLGetSize(out_offset)); + (void)memcpy(out_mean->data_, current_mean, NNACLGetSize(out_mean)); + (void)memcpy(out_var->data_, current_var, NNACLGetSize(out_var)); + } + + return NNACL_OK; +} + +int FusedBatchNormCompute(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + int ret = FusedBatchNormTrainComputeInit(fused_batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, FusedBatchNormRun, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int FusedBatchNormReSize(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + int ret = BatchNormFillParam(&fused_batch_norm->bn_); + if (ret != NNACL_OK) { + return ret; + } + + (void)self->Release(self); + + return FusedBatchNormInitConstTensor(fused_batch_norm); +} + +int FusedBatchNormPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < FIVE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + fused_batch_norm->bn_.momentum_ = ((BatchNormParameter *)self->param_)->momentum_; + fused_batch_norm->bn_.epsilon_ = ((BatchNormParameter *)self->param_)->epsilon_; + return NNACL_OK; +} + +int FusedBatchNormRelease(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + (void)BatchNormRelease(&fused_batch_norm->bn_.base_); + + if (fused_batch_norm->scale_ != NULL) { + self->env_->Free(self->env_->allocator_, fused_batch_norm->scale_); + fused_batch_norm->scale_ = NULL; + } + if (fused_batch_norm->offset_ != NULL) { + self->env_->Free(self->env_->allocator_, fused_batch_norm->offset_); + fused_batch_norm->offset_ = NULL; + } + return NNACL_OK; +} + +KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)malloc(sizeof(FusedBatchNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fused_batch_norm); + memset(fused_batch_norm, 0, sizeof(FusedBatchNormStruct)); + fused_batch_norm->bn_.data_type_ = data_type; + fused_batch_norm->bn_.base_.Prepare = FusedBatchNormPrepare; + fused_batch_norm->bn_.base_.Resize = FusedBatchNormReSize; + fused_batch_norm->bn_.base_.Release = FusedBatchNormRelease; + 
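+  /* in inference mode Compute applies batch norm folded into one affine pass per channel: y = x * (scale / sqrt(var + eps)) + (offset - mean * scale / sqrt(var + eps)) */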
fused_batch_norm->bn_.base_.Compute = FusedBatchNormCompute; + return (KernelBase *)fused_batch_norm; +} + +REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat16, CreateFusedBatchNorm) +REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat32, CreateFusedBatchNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..15193c9ca63d5456b0b8e9437df10b6f013ce527 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_FUSED_BATCH_NORM_H_ +#define NNACL_KERNEL_FUSED_BATCH_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/batch_norm.h" +#include "nnacl_c/kernel/scale.h" + +typedef struct FusedBatchNormStruct { + BatchNormStruct bn_; + ScaleStruct scale_param_; + void *scale_; + void *offset_; + bool is_scale_; + bool trained_; + bool train_mode_; +} FusedBatchNormStruct; + +KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FUSED_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c new file mode 100644 index 0000000000000000000000000000000000000000..532640b2042b1cdff612e5e1fcfba401720130b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c @@ -0,0 +1,241 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/gather.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +#define kGatherMinCostPerThread 16384 + +void GatherHandleCopy(GatherStruct *gather, int8_t **int8_in, int8_t **int8_out, int begin, int end, + int byte_in_stride) { + for (; begin < end; ++begin) { + int index = gather->indices_data_[begin]; + index = (index < 0 ? 
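+  /* a negative index wraps once, e.g. -1 -> limit_ - 1; anything still outside [0, limit_) is zero-filled below */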
index + gather->limit_ : index); + if (index < 0 || index >= gather->limit_) { + memset(*int8_out, 0, gather->byte_inner_size_); + } else { + memcpy(*int8_out, *int8_in + index * gather->byte_inner_size_, gather->byte_inner_size_); + } + *int8_out += gather->byte_inner_size_; + } + *int8_in += byte_in_stride; +} + +int GatherRun(void *cdata, int task_id, float l, float r) { + GatherStruct *gather = (GatherStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(gather); + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id >= gather->block_infos_size_, NNACL_ERR); + + int8_t *int8_in = (int8_t *)(gather->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(int8_in); + int8_t *int8_out = (int8_t *)(gather->base_.out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(int8_out); + int begin_batch = gather->block_infos_[task_id].begin_batch_; + int begin_index = gather->block_infos_[task_id].begin_index_; + int end_batch = gather->block_infos_[task_id].end_batch_; + int end_index = gather->block_infos_[task_id].end_index_; + int64_t byte_in_stride = gather->limit_ * gather->byte_inner_size_; + int8_in += begin_batch * byte_in_stride; + int8_out += begin_batch * gather->indices_size_ * gather->byte_inner_size_ + begin_index * gather->byte_inner_size_; + if (begin_batch == end_batch) { + GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, end_index, byte_in_stride); + return NNACL_OK; + } + GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, gather->indices_size_, byte_in_stride); + ++begin_batch; + for (; begin_batch < end_batch; ++begin_batch) { + GatherHandleCopy(gather, &int8_in, &int8_out, 0, gather->indices_size_, byte_in_stride); + } + GatherHandleCopy(gather, &int8_in, &int8_out, 0, end_index, byte_in_stride); + return NNACL_OK; +} + +int AssignGatherIndicesData(GatherStruct *gather, bool is_indices_int32) { + TensorC *indices_tensor = gather->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor->data_); + + if (is_indices_int32) { + gather->indices_data_ = (int *)(indices_tensor->data_); + return NNACL_OK; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->indices_size_, (int)(sizeof(int)), NNACL_ERR); + gather->indices_data_ = + (int *)(gather->base_.env_->Alloc(gather->base_.env_->allocator_, gather->indices_size_ * sizeof(int))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather->indices_data_); + + switch (indices_tensor->data_type_) { + case kNumberTypeInt64: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((int64_t *)indices_tensor->data_)[i]; + } + break; + case kNumberTypeFloat: + case kNumberTypeFloat32: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((float *)indices_tensor->data_)[i]; + } + break; + case kNumberTypeBool: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((bool *)indices_tensor->data_)[i]; + } + break; + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int InitGatherDynamicStatus(GatherStruct *gather) { + int *in_shape = gather->base_.in_[FIRST_INPUT]->shape_; + int in_rank = (int)gather->base_.in_[FIRST_INPUT]->shape_size_; + NNACL_CHECK_TRUE_RET(gather->axis_ >= 0 && gather->axis_ < in_rank, NNACL_GATHER_AXIS_INVALID); + gather->limit_ = in_shape[gather->axis_]; + gather->outer_size_ = 1; + for (int i = 0; i < gather->axis_; ++i) { + gather->outer_size_ *= in_shape[i]; + } + gather->byte_inner_size_ = (int)DataTypeCSize(gather->base_.out_[OUTPUT_INDEX]->data_type_); + for (int i 
= gather->axis_ + 1; i < in_rank; ++i) { + gather->byte_inner_size_ *= in_shape[i]; + } + gather->indices_size_ = NNACLGetElementNum(gather->base_.in_[SECOND_INPUT]); + return NNACL_OK; +} + +void GatherUpdateThreadNumProcess(GatherStruct *gather) { + int all_bytes = NNACLGetSize(gather->base_.out_[OUTPUT_INDEX]); + if (all_bytes <= kGatherMinCostPerThread) { + gather->base_.thread_nr_ = 1; + return; + } + + gather->base_.thread_nr_ = + gather->base_.UpdateThread(TC_PTYPE(PrimType_Gather), 0, gather->byte_inner_size_, + NNACLGetSize(gather->base_.out_[OUTPUT_INDEX]), gather->base_.thread_nr_); + return; +} + +int ChooseGatherThreadCuttingStrategy(GatherStruct *gather) { + gather->block_infos_size_ = 0; + if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) { + return NNACL_OK; + } + GatherUpdateThreadNumProcess(gather); + if (gather->base_.thread_nr_ > GATHER_BLOCK_INFOS_SIZE) { + gather->base_.thread_nr_ = GATHER_BLOCK_INFOS_SIZE; + } + + if (gather->base_.thread_nr_ == 1) { + gather->block_infos_[gather->block_infos_size_].begin_batch_ = 0; + gather->block_infos_[gather->block_infos_size_].begin_index_ = 0; + gather->block_infos_[gather->block_infos_size_].end_batch_ = gather->outer_size_; + gather->block_infos_[gather->block_infos_size_].end_index_ = 0; + gather->block_infos_size_++; + } else { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->outer_size_, gather->indices_size_, NNACL_ERR); + int total_block = gather->outer_size_ * gather->indices_size_; + int block_size = total_block / gather->base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_size, gather->base_.thread_nr_, NNACL_ERR); + int remain_block = total_block - block_size * gather->base_.thread_nr_; + int start = 0; + while (start < total_block) { + GatherBlockBoundaryInfo block_boundary_info; + block_boundary_info.begin_batch_ = start / gather->indices_size_; + block_boundary_info.begin_index_ = start % gather->indices_size_; + start += block_size; + if (remain_block > 0) { + ++start; + --remain_block; + } + if (start >= total_block) { + start = total_block; + } + block_boundary_info.end_batch_ = start / gather->indices_size_; + block_boundary_info.end_index_ = start % gather->indices_size_; + gather->block_infos_[gather->block_infos_size_++] = block_boundary_info; + } + gather->base_.thread_nr_ = gather->block_infos_size_; + } + + return NNACL_OK; +} + +int GatherResize(KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + + int status = InitGatherDynamicStatus(gather); + NNACL_CHECK_FALSE(status != NNACL_OK, status); + + return ChooseGatherThreadCuttingStrategy(gather); +} + +int GatherPrepare(struct KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_GATHER_INPUT_TENSOR_INVALID); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_GATHER_OUTPUT_TENSOR_INVALID); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]->data_); + gather->axis_ = *((int *)self->in_[THIRD_INPUT]->data_); + return NNACL_OK; +} + +int GatherCompute(struct KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + + if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) { + return NNACL_OK; + } + + bool is_indices_int32 = self->in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32; + int ret = 
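+  /* int64 / float32 / bool indices are converted into a temporary int32 buffer first; int32 indices are used in place */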
AssignGatherIndicesData(gather, is_indices_int32); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, GatherRun, gather, gather->base_.thread_nr_); + + if (!is_indices_int32) { + self->env_->Free(self->env_->allocator_, gather->indices_data_); + gather->indices_data_ = NULL; + } + return ret; +} + +KernelBase *CreateGather(OpParameter *param, int data_type) { + GatherStruct *gather = (GatherStruct *)malloc(sizeof(GatherStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather); + gather->indices_data_ = NULL; + gather->block_infos_size_ = 0; + gather->base_.Prepare = GatherPrepare; + gather->base_.Resize = GatherResize; + gather->base_.Release = DefaultRelease; + gather->base_.Compute = GatherCompute; + return (KernelBase *)gather; +} + +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat16, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat32, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeInt32, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeBool, CreateGather) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h new file mode 100644 index 0000000000000000000000000000000000000000..9a95133f566cf8c763d52e14dab418f4740768ca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h @@ -0,0 +1,46 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_GATHER_H_ +#define NNACL_KERNEL_GATHER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +#define GATHER_BLOCK_INFOS_SIZE 32 + +typedef struct GatherBlockBoundaryInfo { + int64_t begin_batch_; + int64_t begin_index_; + int64_t end_batch_; + int64_t end_index_; +} GatherBlockBoundaryInfo; + +typedef struct GatherStruct { + KernelBase base_; + int axis_; + int limit_; + int outer_size_; + int indices_size_; + int byte_inner_size_; + int block_infos_size_; + int *indices_data_; + GatherBlockBoundaryInfo block_infos_[GATHER_BLOCK_INFOS_SIZE]; +} GatherStruct; + +KernelBase *CreateGather(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c new file mode 100644 index 0000000000000000000000000000000000000000..f80909bde5c28acaabe72fcbb4c8c0eb2eab094f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c @@ -0,0 +1,124 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/gather_d.h" +#include "nnacl_c/gather_parameter.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/gather_d_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +typedef struct GatherDStru { + KernelBase base; +} GatherDStru; + +int GatherDPrepare(struct KernelBase *self) { + GatherDStru *gather_d = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d); + GatherParameter *param = (GatherParameter *)gather_d->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < kInputSize2 || self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + NNACL_CHECK_NULL_RETURN_ERR(gather_d->base.in_[1]->data_); + + param->axis_ = ((int *)(gather_d->base.in_[1]->data_))[0]; + return NNACL_OK; +} + +int GatherDResize(struct KernelBase *self) { + GatherDStru *gather_d = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d); + GatherParameter *param = (GatherParameter *)gather_d->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + int input_rank = (int)gather_d->base.in_[0]->shape_size_; + NNACL_CHECK_FALSE(param->axis_ >= input_rank || param->axis_ < -input_rank, NNACL_GATHER_D_AXIS_INVALID); + + if (param->axis_ < 0) { + param->axis_ = param->axis_ + input_rank; + } + return NNACL_OK; +} + +int GatherDCompute(struct KernelBase *self) { + GatherDStru *gather_d_stru = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d_stru); + GatherParameter *param = (GatherParameter *)gather_d_stru->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input = gather_d_stru->base.in_[0]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = gather_d_stru->base.out_[0]; + NNACL_CHECK_NULL_RETURN_ERR(output); + const void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + const void *index_data = gather_d_stru->base.in_[2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(index_data); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + size_t input_shape[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < input->shape_size_; i++) { + input_shape[i] = input->shape_[i]; + } + size_t output_shape[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < output->shape_size_; i++) { + output_shape[i] = output->shape_[i]; + } + + int input_dtype = input->data_type_; + int index_dtype = gather_d_stru->base.in_[THIRD_INPUT]->data_type_; + int status = NNACL_ERR; + if (index_dtype == kNumberTypeInt32) { + if (input_dtype == kNumberTypeFloat32) { + status = GATHER_D(float, int32_t, (float *)output_data, (float *)input_data, (int32_t *)index_data, input_shape, + input->shape_size_, output_shape, output->shape_size_, param->axis_); + } else if (input_dtype == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + status = GATHER_D(float16_t, int32_t, (float16_t *)output_data, (float16_t *)input_data, (int32_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); +#endif + } else if (input_dtype == kNumberTypeInt32) { + status = GATHER_D(int32_t, int32_t, (int32_t *)output_data, (int32_t *)input_data, (int32_t *)index_data, + input_shape, input->shape_size_, output_shape, 
output->shape_size_, param->axis_); + } + } else if (index_dtype == kNumberTypeInt64) { + if (input_dtype == kNumberTypeFloat32) { + status = GATHER_D(float, int64_t, (float *)output_data, (float *)input_data, (int64_t *)index_data, input_shape, + input->shape_size_, output_shape, output->shape_size_, param->axis_); + } else if (input_dtype == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + status = GATHER_D(float16_t, int64_t, (float16_t *)output_data, (float16_t *)input_data, (int64_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); +#endif + } else if (input_dtype == kNumberTypeInt32) { + status = GATHER_D(int32_t, int64_t, (int32_t *)output_data, (int32_t *)input_data, (int64_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); + } + } + return status; +} + +KernelBase *CreateGatherD(OpParameter *param, int data_type) { + GatherDStru *gather_d = (GatherDStru *)malloc(sizeof(GatherDStru)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_d); + gather_d->base.Prepare = GatherDPrepare; + gather_d->base.Resize = GatherDResize; + gather_d->base.Release = DefaultRelease; + gather_d->base.Compute = GatherDCompute; + return (KernelBase *)gather_d; +} + +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeFloat32, CreateGatherD); +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeInt32, CreateGatherD); +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeFloat16, CreateGatherD); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h new file mode 100644 index 0000000000000000000000000000000000000000..eb6f69e58e955a0bea655cecc9bd3b80b469214c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h @@ -0,0 +1,25 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_GATHER_D_H_ +#define NNACL_KERNEL_GATHER_D_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateGatherD(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_D_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c new file mode 100644 index 0000000000000000000000000000000000000000..c1b29047368f5fcb62ef215186f2d8bcc896c0b5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c @@ -0,0 +1,168 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/gather_nd.h" +#include "nnacl_c/fp32/gatherNd_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/nnacl_common.h" + +int GatherNdInitOffset(GatherNdStruct *gather_nd) { + TensorC *input_tensor = gather_nd->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *indices_tensor = gather_nd->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor); + + if (indices_tensor->shape_size_ < 1) { + return NNACL_GATHER_ND_INDICES_RANK_INVALID; + } + + int in_rank = input_tensor->shape_size_; + int idx_lastshape = indices_tensor->shape_[indices_tensor->shape_size_ - 1]; + if (idx_lastshape > in_rank) { + return NNACL_GATHER_ND_INDICES_SHAPE_INVALID; + } + + gather_nd->area_ = 1; + for (int i = idx_lastshape; i < (int)input_tensor->shape_size_; ++i) { + gather_nd->area_ *= input_tensor->shape_[i]; + } + + int in_stride[MAX_SHAPE_SIZE] = {0}; + in_stride[in_rank - 1] = 1; + for (int i = in_rank - 2; i >= 0; --i) { + in_stride[i] = input_tensor->shape_[i + 1] * in_stride[i + 1]; + } + + int idx_stride = idx_lastshape; + (void)memset(gather_nd->in_offset_, 0, gather_nd->count_ * sizeof(int)); + + if (indices_tensor->data_type_ == kNumberTypeInt || indices_tensor->data_type_ == kNumberTypeInt32) { + int32_t *indices_ptr = (int32_t *)indices_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_ptr); + for (int j = 0; j < gather_nd->count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + } else if (indices_tensor->data_type_ == kNumberTypeInt64) { + int64_t *indices_ptr = (int64_t *)indices_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_ptr); + for (int j = 0; j < gather_nd->count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + } else { + return NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID; + } + + return NNACL_OK; +} + +int GatherNdRun(void *cdata, int task_id, float l, float r) { + GatherNdStruct *gather_nd = (GatherNdStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + TensorC *input = gather_nd->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, gather_nd->thread_stride_, NNACL_ERR); + int count = NNACL_MIN(gather_nd->thread_stride_, gather_nd->count_ - task_id * gather_nd->thread_stride_); + if (count <= 0) { + return NNACL_OK; + } + + int offset = task_id * gather_nd->thread_stride_; + int dtype_len = DataTypeCSize(input->data_type_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(offset, gather_nd->area_, NNACL_ERR); + int8_t *out_ptr = (int8_t *)gather_nd->out_ptr_ + offset * gather_nd->area_ * dtype_len; + return GatherNd(gather_nd->in_ptr_, out_ptr, gather_nd->in_offset_ + offset, gather_nd->area_, count, dtype_len); +} + +int GatherNdCompute(KernelBase *self) { + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + gather_nd->in_ptr_ = input->data_; +
NNACL_CHECK_NULL_RETURN_ERR(gather_nd->in_ptr_); + + TensorC *output = self->out_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(output); + gather_nd->out_ptr_ = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd->out_ptr_); + + int ret = GatherNdInitOffset(gather_nd); + if (ret != NNACL_OK) { + return ret; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, GatherNdRun, self, self->thread_nr_); +} + +int GatherNdRelease(KernelBase *self) { + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + if (gather_nd->in_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, gather_nd->in_offset_); + gather_nd->in_offset_ = NULL; + } + return NNACL_OK; +} + +int GatherNdResize(KernelBase *self) { + (void)self->Release(self); + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + TensorC *indices_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor); + + gather_nd->count_ = 1; + for (int i = 0; i < (int)indices_tensor->shape_size_ - 1; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather_nd->count_, indices_tensor->shape_[i], NNACL_ERR); + gather_nd->count_ *= indices_tensor->shape_[i]; + } + + int max_count = INT32_MAX / sizeof(int); + if (gather_nd->count_ >= max_count) { + return NNACL_GATHER_ND_COUNT_INVALID; + } + + gather_nd->in_offset_ = self->env_->Alloc(self->env_->allocator_, gather_nd->count_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather_nd->in_offset_); + + gather_nd->base_.thread_nr_ = NNACL_MIN(gather_nd->base_.thread_nr_, gather_nd->count_); + if (gather_nd->base_.thread_nr_ != 0) { + gather_nd->thread_stride_ = UP_DIV(gather_nd->count_, gather_nd->base_.thread_nr_); + } + return NNACL_OK; +} + +KernelBase *CreateGatherNd(OpParameter *param, int data_type) { + GatherNdStruct *gather_nd = (GatherNdStruct *)malloc(sizeof(GatherNdStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_nd); + memset(gather_nd, 0, sizeof(GatherNdStruct)); + + gather_nd->base_.Prepare = DefaultPrepare2In1Out; + gather_nd->base_.Resize = GatherNdResize; + gather_nd->base_.Compute = GatherNdCompute; + gather_nd->base_.Release = GatherNdRelease; + return (KernelBase *)gather_nd; +} + +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeBool, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeInt32, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat32, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat16, CreateGatherNd); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h new file mode 100644 index 0000000000000000000000000000000000000000..bb1f87f84fe3495ad8a8f8ca7d8f405afc88b4e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
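The precomputation in GatherNdInitOffset above folds each row of the indices tensor into a single flat element offset using the input's row-major strides; GatherNdRun then copies `area_` contiguous elements per offset. A self-contained sketch of that fold, with a hypothetical 3x4x5 input and index rows of length 2:

```c
#include <stdio.h>

/* Fold one index row (length idx_len) into a flat offset via row-major
 * strides, mirroring the loop in GatherNdInitOffset. */
static int flat_offset(const int *shape, int rank, const int *idx_row, int idx_len) {
  int stride[8];
  stride[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * shape[i + 1];
  }
  int offset = 0;
  for (int k = 0; k < idx_len; ++k) {
    offset += idx_row[k] * stride[k];
  }
  return offset;
}

int main(void) {
  int shape[3] = {3, 4, 5}; /* strides: 20, 5, 1 */
  int row[2] = {2, 3};      /* selects input[2][3][:], so area_ would be 5 */
  printf("offset = %d\n", flat_offset(shape, 3, row, 2)); /* 2*20 + 3*5 = 55 */
  return 0;
}
```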
+ */ +#ifndef NNACL_KERNEL_GATHER_ND_H_ +#define NNACL_KERNEL_GATHER_ND_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int *in_offset_; + int count_; + int area_; + int thread_stride_; + void *in_ptr_; + void *out_ptr_; +} GatherNdStruct; + +KernelBase *CreateGatherNd(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_ND_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c new file mode 100644 index 0000000000000000000000000000000000000000..9a68f98cda5e6c230c92d6b6748c02a05fe85462 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c @@ -0,0 +1,419 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/group_convolution.h" +#include "nnacl_c/kernel/convolution_delegate.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int GroupConvBasePrepare(GroupConvolutionStruct *group_conv) { + for (int i = 0; i < group_conv->group_; ++i) { + KernelBase *sub_conv = group_conv->group_convs_[i]; + NNACL_CHECK_NULL_RETURN_ERR(sub_conv); + int ret = sub_conv->Prepare(sub_conv); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int GroupConvCreatorNewInputTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv) { + TensorC *in_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(in_tensor); + in_tensor->format_ = Format_NHWC; + in_tensor->category_ = VarTensor; + in_tensor->data_type_ = group_conv->data_type_; + in_tensor->shape_size_ = DIMENSION_4D; + in_tensor->shape_[Index0] = INVALID_SHAPE; + new_conv->in_[FIRST_INPUT] = in_tensor; + return NNACL_OK; +} + +int GroupConvCreatorNewOutputTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv) { + TensorC *out_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_tensor); + out_tensor->format_ = Format_NHWC; + out_tensor->category_ = VarTensor; + out_tensor->data_type_ = group_conv->data_type_; + out_tensor->shape_size_ = DIMENSION_4D; + out_tensor->shape_[Index0] = INVALID_SHAPE; + new_conv->out_[OUTPUT_INDEX] = out_tensor; + return NNACL_OK; +} + +TensorC *CreateConstTensor(const TensorC *tensor, const int *shape, const int shape_size, const int index) { + NNACL_CHECK_NULL_RETURN_NULL(tensor->data_); + + TensorC *new_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(new_tensor); + new_tensor->data_type_ = tensor->data_type_; + new_tensor->format_ = Format_NHWC; + new_tensor->category_ = ConstTensor; + new_tensor->shape_size_ = shape_size; + memcpy(new_tensor->shape_, shape, shape_size * sizeof(int)); + + int size = NNACLGetSize(new_tensor); + if (size <= 0) { + free(new_tensor); + return NULL; + } + + void *data = malloc(size); + if (data == NULL) { 
+ free(new_tensor); + return NULL; + } + new_tensor->data_ = data; + + uint8_t *new_tensor_data = (uint8_t *)tensor->data_ + index * size; + memcpy(new_tensor->data_, new_tensor_data, size); + return new_tensor; +} + +int GroupConvCreatorNewConstTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv, int group_id) { + TensorC *origin_weight = group_conv->conv_base_.base_.in_[SECOND_INPUT]; + int shape[] = {group_conv->sub_out_c_, NNACLGetHeight(origin_weight), NNACLGetWidth(origin_weight), + group_conv->sub_in_c_}; + TensorC *weight_tensor = CreateConstTensor(origin_weight, shape, DIMENSION_4D, group_id); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(weight_tensor); + new_conv->in_[SECOND_INPUT] = weight_tensor; + + if (group_conv->conv_base_.base_.in_size_ == THREE_TENSOR) { + TensorC *bias_weight = group_conv->conv_base_.base_.in_[THIRD_INPUT]; + TensorC *bias_tensor = CreateConstTensor(bias_weight, &group_conv->sub_out_c_, DIMENSION_1D, group_id); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(bias_tensor); + new_conv->in_[THIRD_INPUT] = bias_tensor; + } + return NNACL_OK; +} + +int GroupConvCreatorSetShapeOfTensors(GroupConvolutionStruct *group_conv) { + ConvParameter *origin_conv_param = (ConvParameter *)group_conv->conv_base_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(origin_conv_param); + ConvParameter *new_conv_param = &group_conv->new_conv_param_; + NNACL_CHECK_NULL_RETURN_ERR(new_conv_param); + memcpy(new_conv_param, origin_conv_param, sizeof(ConvParameter)); + + TensorC *weight_tensor = group_conv->conv_base_.base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + NNACL_CHECK_FALSE(origin_conv_param->group_ == 0, NNACL_GROUP_CONVOLUTION_GROUP_INVALID); + NNACL_CHECK_FALSE(weight_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + NNACL_CHECK_FALSE(origin_conv_param->kernel_h_ != NNACLGetHeight(weight_tensor), + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + NNACL_CHECK_FALSE(origin_conv_param->kernel_w_ != NNACLGetWidth(weight_tensor), + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + ConvComputeParam *compute = &group_conv->conv_base_.compute_; + group_conv->ori_in_c_ = compute->in_c_; + group_conv->ori_out_c_ = compute->out_c_; + group_conv->sub_in_c_ = compute->in_c_ / group_conv->group_; + group_conv->sub_out_c_ = compute->out_c_ / group_conv->group_; + + new_conv_param->input_channel_ = group_conv->sub_in_c_; + new_conv_param->output_channel_ = group_conv->sub_out_c_; + new_conv_param->group_ = origin_conv_param->group_; + + return NNACL_OK; +} + +int GroupConvSetSubConvInfo(GroupConvolutionStruct *group_conv, KernelBase *new_conv, int group_id) { + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + NNACL_CHECK_NULL_RETURN_ERR(new_conv); + + ConvolutionBaseStruct *sub_conv = (ConvolutionBaseStruct *)new_conv; + (void)ConvBaseUpdateParamInfo(&sub_conv->compute_, &group_conv->new_conv_param_); + + sub_conv->infershape_done_ = group_conv->conv_base_.infershape_done_; + sub_conv->shaing_manager_ = group_conv->conv_base_.shaing_manager_; + sub_conv->get_sharing_weight_ = group_conv->conv_base_.get_sharing_weight_; + sub_conv->free_sharing_weight_ = group_conv->conv_base_.free_sharing_weight_; + sub_conv->is_sharing_pack_ = group_conv->conv_base_.is_sharing_pack_; + + new_conv->env_ = group_conv->conv_base_.base_.env_; + new_conv->param_ = &group_conv->new_conv_param_.op_parameter_; + new_conv->thread_nr_ = group_conv->conv_base_.base_.thread_nr_; + new_conv->train_session_ = group_conv->conv_base_.base_.train_session_; + new_conv->UpdateThread = 
group_conv->conv_base_.base_.UpdateThread; + new_conv->in_size_ = group_conv->conv_base_.base_.in_size_; + new_conv->out_size_ = group_conv->conv_base_.base_.out_size_; + + new_conv->in_ = (TensorC **)malloc(new_conv->in_size_ * sizeof(TensorC *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv->in_); + memset(new_conv->in_, 0, new_conv->in_size_ * sizeof(TensorC *)); + new_conv->out_ = (TensorC **)malloc(new_conv->out_size_ * sizeof(TensorC *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv->out_); + memset(new_conv->out_, 0, new_conv->out_size_ * sizeof(TensorC *)); + + // create new input for each group + int ret = GroupConvCreatorNewInputTensor(group_conv, new_conv); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + + // const tensor + ret = GroupConvCreatorNewConstTensor(group_conv, new_conv, group_id); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + + // create new output tensor + ret = GroupConvCreatorNewOutputTensor(group_conv, new_conv); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + return NNACL_OK; +} + +int GroupConvConcatOutputRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)cdata; + + int plane_step = UP_DIV(group_conv->conv_base_.compute_.out_hw_, group_conv->conv_base_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_step, task_id, NNACL_ERR); + int begin_plane = plane_step * task_id; + int end_plane = NNACL_MIN(group_conv->conv_base_.compute_.out_hw_, plane_step * (task_id + 1)); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->sub_out_c_, NNACL_ERR); + float *src_ptr = group_conv->sub_out_src_ + begin_plane * group_conv->sub_out_c_; + float *dst_ptr = group_conv->sub_out_dst_ + begin_plane * group_conv->ori_out_c_; + for (int i = begin_plane; i < end_plane; ++i) { + (void)memcpy(dst_ptr, src_ptr, group_conv->sub_out_c_ * sizeof(float)); + src_ptr += group_conv->sub_out_c_; + dst_ptr += group_conv->ori_out_c_; + } + return NNACL_OK; +} + +int GroupConvPostConcat(GroupConvolutionStruct *group_conv, int group_id) { + group_conv->sub_out_src_ = (float *)group_conv->group_convs_[group_id]->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_out_src_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(group_id, group_conv->sub_out_c_, NNACL_ERR); + group_conv->sub_out_dst_ = (float *)(group_conv->origin_output_data_) + group_id * group_conv->sub_out_c_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_out_dst_); + + return group_conv->conv_base_.base_.env_->ParallelLaunch(group_conv->conv_base_.base_.env_->thread_pool_, + GroupConvConcatOutputRun, group_conv, + group_conv->conv_base_.base_.thread_nr_); +} + +int GroupConvSeparateInputRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)cdata; + + int plane_step = UP_DIV(group_conv->conv_base_.compute_.in_hw_, group_conv->conv_base_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_step, task_id, NNACL_ERR); + int begin_plane = plane_step * task_id; + int end_plane = NNACL_MIN(group_conv->conv_base_.compute_.in_hw_, plane_step * (task_id + 1)); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->ori_in_c_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->sub_in_c_, NNACL_ERR); + 
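GroupConvConcatOutputRun above and GroupConvSeparateInputRun (continued below) are both strided channel copies over NHWC data: each spatial position owns a contiguous slice of `sub_c` channels that sits at channel offset `group_id * sub_c` inside the full `ori_c` channel vector. A minimal single-threaded sketch of the split direction (hypothetical sizes):

```c
#include <stdio.h>
#include <string.h>

/* Copy one group's channel slice out of an NHWC buffer, mirroring
 * GroupConvSeparateInputRun: dst is plane-major with sub_c channels. */
static void split_group(const float *src, float *dst, int planes,
                        int ori_c, int sub_c, int group_id) {
  const float *s = src + group_id * sub_c;
  for (int p = 0; p < planes; ++p) {
    memcpy(dst + p * sub_c, s + p * ori_c, sub_c * sizeof(float));
  }
}

int main(void) {
  /* 2 spatial positions, 4 channels, 2 groups of 2 channels each. */
  const float in[2 * 4] = {0, 1, 2, 3, 10, 11, 12, 13};
  float g1[2 * 2];
  split_group(in, g1, 2, 4, 2, 1); /* expect: 2 3 12 13 */
  for (int i = 0; i < 4; ++i) printf("%g ", g1[i]);
  printf("\n");
  return 0;
}
```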
float *src_ptr = group_conv->sub_in_src_ + begin_plane * group_conv->ori_in_c_; + float *dst_ptr = group_conv->sub_in_dst_ + begin_plane * group_conv->sub_in_c_; + for (int i = begin_plane; i < end_plane; ++i) { + (void)memcpy(dst_ptr, src_ptr, group_conv->sub_in_c_ * sizeof(float)); + src_ptr += group_conv->ori_in_c_; + dst_ptr += group_conv->sub_in_c_; + } + + return NNACL_OK; +} + +int GroupConvSeparateInput(GroupConvolutionStruct *group_conv, int group_id) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(group_id, group_conv->sub_in_c_, NNACL_ERR); + + group_conv->sub_in_src_ = (float *)(group_conv->origin_input_data_) + group_id * group_conv->sub_in_c_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_in_src_); + group_conv->sub_in_dst_ = (float *)(group_conv->group_convs_[group_id]->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_in_dst_); + + return group_conv->conv_base_.base_.env_->ParallelLaunch(group_conv->conv_base_.base_.env_->thread_pool_, + GroupConvSeparateInputRun, group_conv, + group_conv->conv_base_.base_.thread_nr_); +} + +void GroupConvUpdateShape(GroupConvolutionStruct *group_conv) { + for (int i = 0; i < group_conv->group_; i++) { + TensorC *in_tensor = group_conv->conv_base_.base_.in_[FIRST_INPUT]; + int in_shape[] = {NNACLGetBatch(in_tensor), NNACLGetHeight(in_tensor), NNACLGetWidth(in_tensor), + group_conv->sub_in_c_}; + memcpy(group_conv->group_convs_[i]->in_[FIRST_INPUT]->shape_, in_shape, DIMENSION_4D * sizeof(int)); + + TensorC *out_tensor = group_conv->conv_base_.base_.out_[OUTPUT_INDEX]; + int out_shape[] = {NNACLGetBatch(out_tensor), NNACLGetHeight(out_tensor), NNACLGetWidth(out_tensor), + group_conv->sub_out_c_}; + memcpy(group_conv->group_convs_[i]->out_[OUTPUT_INDEX]->shape_, out_shape, DIMENSION_4D * sizeof(int)); + } + return; +} + +int GroupConvolutionResize(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + + (void)ConvBaseUpdateComputeInfo(&group_conv->conv_base_); + self->thread_nr_ = NNACL_MIN(NNACL_MAX(1, self->thread_nr_), group_conv->conv_base_.compute_.in_hw_); + self->thread_nr_ = NNACL_MIN(NNACL_MAX(1, self->thread_nr_), group_conv->conv_base_.compute_.out_hw_); + + GroupConvUpdateShape(group_conv); + + for (int i = 0; i < group_conv->group_; ++i) { + group_conv->group_convs_[i]->thread_nr_ = self->thread_nr_; + int ret = group_conv->group_convs_[i]->Resize(group_conv->group_convs_[i]); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int GroupConvolutionCompute(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + + group_conv->origin_input_data_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->origin_input_data_); + group_conv->origin_output_data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->origin_output_data_); + + for (int i = 0; i < group_conv->group_; ++i) { + // first, malloc data for sub_kernel's tensors.
+ TensorC *sub_kernel_in_tensor = group_conv->group_convs_[i]->in_[FIRST_INPUT]; + sub_kernel_in_tensor->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(sub_kernel_in_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sub_kernel_in_tensor->data_); + + TensorC *sub_kernel_out_tensor = group_conv->group_convs_[i]->out_[OUTPUT_INDEX]; + sub_kernel_out_tensor->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(sub_kernel_out_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sub_kernel_out_tensor->data_); + + // second, separate group conv input into several parts. This step must be in runtime stage. + int ret = GroupConvSeparateInput(group_conv, i); + if (ret != NNACL_OK) { + return ret; + } + + // sub kernels run + ret = group_conv->group_convs_[i]->Compute(group_conv->group_convs_[i]); + if (ret != NNACL_OK) { + return ret; + } + + // post process, concat all outputs of sub-kernels into one output + ret = GroupConvPostConcat(group_conv, i); + if (ret != NNACL_OK) { + return ret; + } + + // Free data + self->env_->Free(self->env_->allocator_, sub_kernel_in_tensor->data_); + sub_kernel_in_tensor->data_ = NULL; + self->env_->Free(self->env_->allocator_, sub_kernel_out_tensor->data_); + sub_kernel_out_tensor->data_ = NULL; + } + return NNACL_OK; +} + +int GroupConvolutionPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + NNACL_CHECK_FALSE(group_conv->group_ == 0, NNACL_GROUP_CONVOLUTION_GROUP_INVALID); + + int ret = GroupConvCreatorSetShapeOfTensors(group_conv); + if (ret != NNACL_OK) { + return ret; + } + + group_conv->group_convs_ = (KernelBase **)malloc(group_conv->group_ * sizeof(KernelBase *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(group_conv->group_convs_); + memset(group_conv->group_convs_, 0, group_conv->group_ * sizeof(KernelBase *)); + + for (int i = 0; i < group_conv->group_; ++i) { + KernelBase *new_conv = CreateConvlutionDelegate(&group_conv->new_conv_param_); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv); + group_conv->group_convs_[i] = new_conv; + + ret = GroupConvSetSubConvInfo(group_conv, new_conv, i); + if (ret != NNACL_OK) { + return ret; + } + } + return GroupConvBasePrepare(group_conv); +} + +void GroupConvReleaseSubConv(KernelBase *current_conv) { + (void)current_conv->Release(current_conv); + + if (current_conv->in_ != NULL) { + for (int j = 0; j < current_conv->in_size_; j++) { + if (current_conv->in_[j] != NULL) { + if (NNACLIsConst(current_conv->in_[j])) { + free(current_conv->in_[j]->data_); + current_conv->in_[j]->data_ = NULL; + } + free(current_conv->in_[j]); + current_conv->in_[j] = NULL; + } + } + free(current_conv->in_); + current_conv->in_ = NULL; + } + + if (current_conv->out_ != NULL) { + for (int j = 0; j < current_conv->out_size_; j++) { + if (current_conv->out_[j] != NULL) { + free(current_conv->out_[j]); + current_conv->out_[j] = NULL; + } + } + free(current_conv->out_); + current_conv->out_ = NULL; + } +} + +int GroupConvolutionRelease(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (group_conv->group_convs_ != NULL) { + for (int i = 0; i < conv_param->group_; i++) { + if (group_conv->group_convs_[i] != NULL) { +
GroupConvReleaseSubConv(group_conv->group_convs_[i]); + free(group_conv->group_convs_[i]); + group_conv->group_convs_[i] = NULL; + } + } + free(group_conv->group_convs_); + group_conv->group_convs_ = NULL; + } + return NNACL_OK; +} + +KernelBase *CreateGroupConvolution(ConvParameter *conv_param, TypeIdC data_type) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)malloc(sizeof(GroupConvolutionStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(group_conv); + memset(group_conv, 0, sizeof(GroupConvolutionStruct)); + + group_conv->data_type_ = data_type; + group_conv->group_ = conv_param->group_; + group_conv->conv_base_.base_.Compute = GroupConvolutionCompute; + group_conv->conv_base_.base_.Resize = GroupConvolutionResize; + group_conv->conv_base_.base_.Prepare = GroupConvolutionPrepare; + group_conv->conv_base_.base_.Release = GroupConvolutionRelease; + return (KernelBase *)group_conv; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..4d061b6ec53a092fcb1d6d73e97bb3e85f307fb1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_GROUP_CONVOLUTION_H_ +#define NNACL_KERNEL_GROUP_CONVOLUTION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct GroupConvolutionStruct { + ConvolutionBaseStruct conv_base_; + KernelBase **group_convs_; + ConvParameter new_conv_param_; + TypeIdC data_type_; + int group_; + + void *origin_input_data_; + void *origin_output_data_; + + float *sub_in_src_; + float *sub_in_dst_; + float *sub_out_src_; + float *sub_out_dst_; + + int sub_in_c_; + int ori_in_c_; + int sub_out_c_; + int ori_out_c_; +} GroupConvolutionStruct; + +KernelBase *CreateGroupConvolution(ConvParameter *conv_param, TypeIdC data_type); + +#endif // NNACL_KERNEL_GROUP_CONVOLUTION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..aabbfb94aaf42e6558e9c9f0e29c9f4ca4805d7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c @@ -0,0 +1,122 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
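To make the weight slicing in CreateConstTensor concrete: the grouped weight tensor is laid out group-major, so group `g` owns one contiguous block of `sub_out_c * kh * kw * sub_in_c` elements, and the copy source is `data + g * block_bytes`, which is exactly the `index * size` arithmetic above. A small worked sketch (hypothetical sizes):

```c
#include <stdio.h>

int main(void) {
  /* Hypothetical grouped conv: 2 groups, per-group weights 4x3x3x8 (OHWI). */
  int group = 2, sub_out_c = 4, kh = 3, kw = 3, sub_in_c = 8;
  int elems_per_group = sub_out_c * kh * kw * sub_in_c; /* 288 elements */
  for (int g = 0; g < group; ++g) {
    /* Matches CreateConstTensor: source = base + index * size (bytes). */
    long byte_offset = (long)g * elems_per_group * (long)sizeof(float);
    printf("group %d weight slice starts at byte %ld\n", g, byte_offset);
  }
  return 0;
}
```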
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/group_norm.h" +#include "nnacl_c/fp32/group_norm_fp32.h" +#include "nnacl_c/group_norm_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/tensor_c_utils.h" + +int GroupNormResize(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < kInputSize2, NNACL_TENSOR_SIZE_INVALID); + NNACL_CHECK_FALSE(self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + self->Release(self); + + TensorC *in0 = self->in_[0]; + NNACL_CHECK_FALSE(in0->shape_size_ < C1NUM, NNACL_GROUP_NORM_SHAPE_SIZE_INVALID); + NNACL_CHECK_FALSE(in0->format_ != Format_NCHW, NNACL_GROUP_NORM_FORMAT_INVALID); + + param->unit_ = NNACLGetHeight(in0) * NNACLGetWidth(in0); + param->batch_ = NNACLGetBatch(in0); + param->channel_ = NNACLGetChannel(in0); + return self->Prepare(self); +} + +int GroupNormPrepare(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(param->num_groups_ < 0, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + NNACL_CHECK_FALSE(param->num_groups_ == 0, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + NNACL_CHECK_FALSE(param->channel_ % param->num_groups_, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + + size_t mean_var_elem_num = param->num_groups_; + param->mean_ = malloc(mean_var_elem_num * sizeof(float)); + param->variance_ = malloc(mean_var_elem_num * sizeof(float)); + if (param->mean_ == NULL || param->variance_ == NULL) { + self->Release(self); + return NNACL_MALLOC_BUFFER_FAILED; + } + return NNACL_OK; +} + +int GroupNormRelease(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + if (param->mean_ != NULL) { + free(param->mean_); + param->mean_ = NULL; + } + if (param->variance_ != NULL) { + free(param->variance_); + param->variance_ = NULL; + } + + return NNACL_OK; +} + +int GroupNormImpl(void *param, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(param); + GroupNormStru *groupnorm_stru = (GroupNormStru *)param; + GroupNormParameter *groupnorm_param = (GroupNormParameter *)groupnorm_stru->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param); + + const void *input_data = groupnorm_stru->base.in_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + const void *scale_data = groupnorm_stru->base.in_[C1NUM]->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale_data); + const void *offset_data = groupnorm_stru->base.in_[C2NUM]->data_; + NNACL_CHECK_NULL_RETURN_ERR(offset_data); + void *output_data = groupnorm_stru->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + +
NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param->mean_); + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param->variance_); + + int ret = GroupNormFp32(input_data, scale_data, offset_data, groupnorm_param->mean_, groupnorm_param->variance_, + groupnorm_param, task_id, output_data); + + return ret; +} + +int GroupNormCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, GroupNormImpl, self, self->param_->thread_num_); +} + +KernelBase *CreateGroupNorm(OpParameter *param, int data_type) { + GroupNormStru *groupnorm = (GroupNormStru *)malloc(sizeof(GroupNormStru)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(groupnorm); + + groupnorm->base.Prepare = GroupNormPrepare; + groupnorm->base.Resize = GroupNormResize; + groupnorm->base.Release = GroupNormRelease; + groupnorm->base.Compute = GroupNormCompute; + + return (void *)groupnorm; +} + +REG_KERNEL_CREATOR(PrimType_GroupNormFusion, kNumberTypeFloat32, CreateGroupNorm); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..79ff15ab9b32de123becd4e0b66f9b9b199a3c0a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_GROUP_NORM_H_ +#define NNACL_KERNEL_GROUP_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/group_norm_parameter.h" +#include "nnacl_c/kernel.h" + +typedef struct GroupNormStru { + KernelBase base; +} GroupNormStru; + +KernelBase *CreateGroupNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GROUP_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c new file mode 100644 index 0000000000000000000000000000000000000000..0913d714f641ad9b5a84cea9da292ddeb3ed03b9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c @@ -0,0 +1,51 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
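For orientation, the GroupNormFp32 call above implements y = (x - mean_g) / sqrt(var_g + epsilon) * gamma_c + beta_c, where mean_g and var_g are computed per group of channels over all spatial positions while gamma and beta stay per channel. A reference sketch of that math for one NCHW batch (illustrative only; the epsilon value is an assumption):

```c
#include <math.h>
#include <stdio.h>

/* Reference group norm for one batch of NCHW data: channels_per_group
 * channels share one mean/variance; gamma/beta remain per-channel. */
static void group_norm_ref(const float *x, const float *gamma, const float *beta,
                           float *y, int channels, int hw, int num_groups, float eps) {
  int cpg = channels / num_groups;
  for (int g = 0; g < num_groups; ++g) {
    const float *xs = x + g * cpg * hw;
    int n = cpg * hw;
    float mean = 0.f, var = 0.f;
    for (int i = 0; i < n; ++i) mean += xs[i];
    mean /= n;
    for (int i = 0; i < n; ++i) var += (xs[i] - mean) * (xs[i] - mean);
    var /= n;
    float inv = 1.f / sqrtf(var + eps);
    for (int c = 0; c < cpg; ++c) {
      int ch = g * cpg + c;
      for (int i = 0; i < hw; ++i) {
        int at = ch * hw + i;
        y[at] = (x[at] - mean) * inv * gamma[ch] + beta[ch];
      }
    }
  }
}

int main(void) {
  const float x[4] = {1, 3, 2, 6}; /* 2 channels, hw = 2, 1 group */
  const float gamma[2] = {1, 1}, beta[2] = {0, 0};
  float y[4];
  group_norm_ref(x, gamma, beta, y, 2, 2, 1, 1e-5f);
  for (int i = 0; i < 4; ++i) printf("%g ", y[i]);
  printf("\n");
  return 0;
}
```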
+ */ + +#include "nnacl_c/kernel/init_exec_env.h" + +#define NNACLMaxAllocSize (2000 * 1024 * 1024) +ExecEnv nnacl_default_env; + +void *NNACLDefaultAlloc(void *allocator, size_t sz) { + if (sz == 0 || sz > NNACLMaxAllocSize) { + return NULL; + } + return malloc(sz); +} + +void NNACLDefaultFree(void *allocator, void *ptr) { free(ptr); } + +int NNACLDefaultParallelLaunch(void *threadPool, void *task, void *param, int taskNr) { + int (*function)(void *cdata, int task_id, float l, float r) = task; + int ret = 0; + for (int i = 0; i < taskNr; i++) { + ret += function(param, i, 0, 1); + } + return ret == NNACL_OK ? NNACL_OK : NNACL_ERR; +} + +void InitDefaultExecEnv(void) { + nnacl_default_env.Free = NNACLDefaultFree; + nnacl_default_env.Alloc = NNACLDefaultAlloc; + nnacl_default_env.ParallelLaunch = NNACLDefaultParallelLaunch; +} + +void CheckExecEnv(KernelBase *base) { + if (base->env_ == NULL) { + base->env_ = &nnacl_default_env; + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h new file mode 100644 index 0000000000000000000000000000000000000000..ee41705163af27af9de522aaf94f90edaf7b8945 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_INIT_EXEC_ENV_H_ +#define NNACL_KERNEL_INIT_EXEC_ENV_H_ + +#include "nnacl_c/kernel.h" + +#ifndef _MSC_VER +__attribute__((constructor(103))) void InitDefaultExecEnv(void); +#endif + +void CheckExecEnv(KernelBase *base); + +#endif // NNACL_KERNEL_INIT_EXEC_ENV_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c new file mode 100644 index 0000000000000000000000000000000000000000..43a29485a05e1d95713a6ebd7048f5df7afcbb5f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c @@ -0,0 +1,357 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
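The default environment above lets a kernel run without an attached runtime: Alloc/Free map onto plain malloc/free, and the fallback launcher simply executes every task slice sequentially on the calling thread. A hedged usage sketch of that contract (the task signature mirrors the code above; names and sizes are hypothetical):

```c
#include <stdio.h>

/* Task with the same shape as the NNACL parallel task callbacks:
 * (payload, task_id, left, right) -> status code. */
static int fill_task(void *cdata, int task_id, float l, float r) {
  (void)l; (void)r;
  int *buf = (int *)cdata;
  buf[task_id] = task_id * task_id; /* each "thread" fills its own slot */
  return 0;
}

/* Sequential fallback launcher, same idea as the default launcher above. */
static int launch_seq(void *task_param, int task_nr,
                      int (*task)(void *, int, float, float)) {
  int ret = 0;
  for (int i = 0; i < task_nr; i++) {
    ret += task(task_param, i, 0, 1);
  }
  return ret;
}

int main(void) {
  int buf[4] = {0};
  int ret = launch_seq(buf, 4, fill_task);
  printf("ret=%d: %d %d %d %d\n", ret, buf[0], buf[1], buf[2], buf[3]);
  return 0;
}
```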
+ */ + +#include "nnacl_c/kernel/init_vs_kernels.h" +#include "nnacl_c/kernel/activation.h" +#include "nnacl_c/kernel/arithmetic.h" +#include "nnacl_c/kernel/arithmetic_compare.h" +#include "nnacl_c/kernel/arithmetic_self.h" +#include "nnacl_c/kernel/arg_min_max.h" +#include "nnacl_c/kernel/addn.h" +#include "nnacl_c/kernel/biasadd.h" +#include "nnacl_c/kernel/batch_norm.h" +#include "nnacl_c/kernel/clip.h" +#include "nnacl_c/kernel/concat.h" +#include "nnacl_c/kernel/crop.h" +#include "nnacl_c/kernel/crop_and_resize.h" +#include "nnacl_c/kernel/exp.h" +#include "nnacl_c/kernel/depth_to_space.h" +#include "nnacl_c/kernel/fill.h" +#include "nnacl_c/kernel/fused_batch_norm.h" +#include "nnacl_c/kernel/fullconnection.h" +#include "nnacl_c/kernel/gather.h" +#include "nnacl_c/kernel/gather_d.h" +#include "nnacl_c/kernel/gather_nd.h" +#include "nnacl_c/kernel/group_norm.h" +#include "nnacl_c/kernel/log_softmax.h" +#include "nnacl_c/kernel/local_response_norm.h" +#include "nnacl_c/kernel/layer_norm.h" +#include "nnacl_c/kernel/matmul.h" +#include "nnacl_c/kernel/non_max_suppression.h" +#include "nnacl_c/kernel/non_zero.h" +#include "nnacl_c/kernel/nllloss.h" +#include "nnacl_c/kernel/prior_box.h" +#include "nnacl_c/kernel/prelu.h" +#include "nnacl_c/kernel/pad.h" +#include "nnacl_c/kernel/pow.h" +#include "nnacl_c/kernel/reshape.h" +#include "nnacl_c/kernel/reverse.h" +#include "nnacl_c/kernel/range.h" +#include "nnacl_c/kernel/rank.h" +#include "nnacl_c/kernel/scale.h" +#include "nnacl_c/kernel/shape.h" +#include "nnacl_c/kernel/reduce.h" +#include "nnacl_c/kernel/ragged_range.h" +#include "nnacl_c/kernel/stack.h" +#include "nnacl_c/kernel/strided_slice.h" +#include "nnacl_c/kernel/softmax.h" +#include "nnacl_c/kernel/size.h" +#include "nnacl_c/kernel/splice.h" +#include "nnacl_c/kernel/tile.h" +#include "nnacl_c/kernel/tril.h" +#include "nnacl_c/kernel/triu.h" +#include "nnacl_c/kernel/transpose.h" +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/kernel/unique.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/kernel/f16/arithmetic_f16.h" +#include "nnacl_c/kernel/f16/arithmetic_compare_f16.h" +#include "nnacl_c/kernel/f16/concat_f16.h" +#include "nnacl_c/kernel/f16/reduce_f16.h" +#include "nnacl_c/kernel/f16/stack_f16.h" +#endif + +void InitVSKernelF16(KernelCreator **creators) { +#ifdef ENABLE_FP16 + creators[PrimType_Abs][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Activation][REGIST_DT(kNumberTypeFloat16)] = CreateActivation; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_AddN][REGIST_DT(kNumberTypeFloat16)] = CreateAddN; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArgMinMax; + creators[PrimType_BatchNorm][REGIST_DT(kNumberTypeFloat16)] = CreateBatchNorm; + creators[PrimType_Ceil][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Concat][REGIST_DT(kNumberTypeFloat16)] = CreateConcatF16; + creators[PrimType_Cos][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Crop][REGIST_DT(kNumberTypeFloat16)] = CreateCrop; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_DepthToSpace][REGIST_DT(kNumberTypeFloat16)] = CreateDepthToSpace; + creators[PrimType_Eltwise][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Erf][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + 
creators[PrimType_Equal][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Fill][REGIST_DT(kNumberTypeFloat16)] = CreateFill; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_FlattenGrad][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Floor][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_FusedBatchNorm][REGIST_DT(kNumberTypeFloat16)] = CreateFusedBatchNorm; + creators[PrimType_Gather][REGIST_DT(kNumberTypeFloat16)] = CreateGather; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeFloat16)] = CreateGatherD; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeFloat16)] = CreateGatherNd; + creators[PrimType_Greater][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_Less][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_LayerNormFusion][REGIST_DT(kNumberTypeFloat16)] = CreateLayerNorm; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Log][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_LogSoftmax][REGIST_DT(kNumberTypeFloat16)] = CreateLogSoftmax; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Neg][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_PadFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePad; + creators[PrimType_PReLUFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePRelu; + creators[PrimType_PowFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePow; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_RealDiv][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeFloat16)] = CreateReduceF16; + creators[PrimType_Rsqrt][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Round][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Reciprocal][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_ScaleFusion][REGIST_DT(kNumberTypeFloat16)] = CreateScale; + creators[PrimType_Shape][REGIST_DT(kNumberTypeFloat16)] = CreateShape; + creators[PrimType_Softmax][REGIST_DT(kNumberTypeFloat16)] = CreateSoftmax; + creators[PrimType_Stack][REGIST_DT(kNumberTypeFloat16)] = CreateStackF16; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeFloat16)] = CreateStridedSlice; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + 
creators[PrimType_SquaredDifference][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Splice][REGIST_DT(kNumberTypeFloat16)] = CreateSplice; + creators[PrimType_Sin][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Size][REGIST_DT(kNumberTypeFloat16)] = CreateSize; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeFloat16)] = CreateSlice; + creators[PrimType_Square][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Sqrt][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeFloat16)] = CreateTile; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat16)] = CreateTriu; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat16)] = CreateTril; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeFloat16)] = CreateTranspose; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Unique][REGIST_DT(kNumberTypeFloat16)] = CreateUnique; +#endif +} + +void InitVSKernelA(KernelCreator **creators) { + creators[PrimType_Abs][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Abs][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticSelf; + creators[PrimType_Activation][REGIST_DT(kNumberTypeFloat32)] = CreateActivation; + creators[PrimType_Activation][REGIST_DT(kNumberTypeUInt32)] = CreateActivation; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_AddN][REGIST_DT(kNumberTypeFloat32)] = CreateAddN; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeInt32)] = CreateArgMinMax; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeInt32)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArgMinMax; + creators[PrimType_BiasAdd][REGIST_DT(kNumberTypeFloat32)] = CreateBiasAdd; + creators[PrimType_BatchNorm][REGIST_DT(kNumberTypeFloat32)] = CreateBatchNorm; + creators[PrimType_Ceil][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Cos][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Clip][REGIST_DT(kNumberTypeFloat)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeFloat32)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeInt)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeInt32)] = CreateClip; + creators[PrimType_Concat][REGIST_DT(kNumberTypeBool)] = CreateConcat; + creators[PrimType_Concat][REGIST_DT(kNumberTypeInt32)] = CreateConcat; + creators[PrimType_Concat][REGIST_DT(kNumberTypeFloat32)] = CreateConcat; + creators[PrimType_Crop][REGIST_DT(kNumberTypeInt32)] = CreateCrop; + creators[PrimType_Crop][REGIST_DT(kNumberTypeFloat32)] = CreateCrop; + creators[PrimType_CropAndResize][REGIST_DT(kNumberTypeFloat32)] = CreateCropAndResize; + creators[PrimType_DepthToSpace][REGIST_DT(kNumberTypeFloat32)] = CreateDepthToSpace; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Eltwise][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Equal][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + 
creators[PrimType_Equal][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_Erf][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ExpFusion][REGIST_DT(kNumberTypeFloat32)] = CreateExp; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeInt8)] = CreateReshape; + creators[PrimType_Fill][REGIST_DT(kNumberTypeBool)] = CreateFill; + creators[PrimType_Fill][REGIST_DT(kNumberTypeInt32)] = CreateFill; + creators[PrimType_Fill][REGIST_DT(kNumberTypeFloat32)] = CreateFill; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_FlattenGrad][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Floor][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_FullConnection][REGIST_DT(kNumberTypeFloat32)] = CreateFullconnection; + creators[PrimType_FusedBatchNorm][REGIST_DT(kNumberTypeFloat32)] = CreateFusedBatchNorm; + creators[PrimType_Gather][REGIST_DT(kNumberTypeFloat32)] = CreateGather; + creators[PrimType_Gather][REGIST_DT(kNumberTypeInt32)] = CreateGather; + creators[PrimType_Gather][REGIST_DT(kNumberTypeBool)] = CreateGather; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeFloat32)] = CreateGatherD; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeInt32)] = CreateGatherD; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeBool)] = CreateGatherNd; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeInt32)] = CreateGatherNd; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeFloat32)] = CreateGatherNd; + creators[PrimType_Greater][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Greater][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_GroupNormFusion][REGIST_DT(kNumberTypeFloat32)] = CreateGroupNorm; +} + +void InitVSKernelI(KernelCreator **creators) { + creators[PrimType_IsFinite][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LayerNormFusion][REGIST_DT(kNumberTypeFloat32)] = CreateLayerNorm; + creators[PrimType_Less][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Less][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + 
creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_Log][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LogSoftmax][REGIST_DT(kNumberTypeFloat32)] = CreateLogSoftmax; + creators[PrimType_Log1p][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeBool)] = CreateArithmeticSelf; + creators[PrimType_LRN][REGIST_DT(kNumberTypeFloat32)] = CreateLocalResponseNorm; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_MatMulFusion][REGIST_DT(kNumberTypeFloat32)] = CreateMatmul; + creators[PrimType_Mod][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Mod][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_NLLLoss][REGIST_DT(kNumberTypeFloat32)] = CreateNLLLoss; + creators[PrimType_Neg][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Neg][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticSelf; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeInt64)] = CreateArithmeticCompare; + creators[PrimType_NonZero][REGIST_DT(kNumberTypeBool)] = CreateNonZero; + creators[PrimType_NonMaxSuppression][REGIST_DT(kNumberTypeFloat32)] = CreateNonMaxSuppression; + creators[PrimType_PadFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePad; + creators[PrimType_PriorBox][REGIST_DT(kNumberTypeFloat32)] = CreatePriorBox; + creators[PrimType_PriorBox][REGIST_DT(kNumberTypeInt8)] = CreatePriorBox; + creators[PrimType_PowFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePow; + creators[PrimType_PReLUFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePRelu; +} + +void InitVSKernelR(KernelCreator **creators) { + creators[PrimType_RaggedRange][REGIST_DT(kNumberTypeInt32)] = CreateRaggedRange; + creators[PrimType_RaggedRange][REGIST_DT(kNumberTypeFloat32)] = CreateRaggedRange; + creators[PrimType_Range][REGIST_DT(kNumberTypeFloat32)] = CreateRange; + creators[PrimType_Range][REGIST_DT(kNumberTypeInt32)] = CreateRange; + creators[PrimType_Range][REGIST_DT(kNumberTypeFloat16)] = CreateRange; + creators[PrimType_Rank][REGIST_DT(kNumberTypeFloat32)] = CreateRank; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_RealDiv][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeBool)] = CreateReduce; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeInt32)] = CreateReduce; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeFloat32)] = CreateReduce; + creators[PrimType_Reciprocal][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; +
creators[PrimType_ReverseV2][REGIST_DT(kNumberTypeInt32)] = CreateReverse; + creators[PrimType_ReverseV2][REGIST_DT(kNumberTypeFloat32)] = CreateReverse; + creators[PrimType_Round][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Rsqrt][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ScaleFusion][REGIST_DT(kNumberTypeFloat32)] = CreateScale; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt32)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeBool)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeFloat32)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt8)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeUInt8)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt64)] = CreateShape; + creators[PrimType_Softmax][REGIST_DT(kNumberTypeFloat32)] = CreateSoftmax; + creators[PrimType_SquaredDifference][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Stack][REGIST_DT(kNumberTypeFloat32)] = CreateStack; + creators[PrimType_Stack][REGIST_DT(kNumberTypeInt32)] = CreateStack; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeFloat32)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt64)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt32)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt8)] = CreateStridedSlice; + creators[PrimType_Square][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Sqrt][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Sin][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Size][REGIST_DT(kNumberTypeInt32)] = CreateSize; + creators[PrimType_Size][REGIST_DT(kNumberTypeFloat32)] = CreateSize; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeInt32)] = CreateSlice; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeFloat32)] = CreateSlice; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_Splice][REGIST_DT(kNumberTypeFloat32)] = CreateSplice; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeInt32)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeFloat32)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeBool)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeUInt8)] = CreateTile; + creators[PrimType_Triu][REGIST_DT(kNumberTypeDouble)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt16)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt8)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt32)] = CreateTriu; + 
creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt16)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt8)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeBool)] = CreateTriu; + creators[PrimType_Tril][REGIST_DT(kNumberTypeDouble)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt16)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt8)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt16)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt8)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeBool)] = CreateTril; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeFloat32)] = CreateTranspose; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeInt32)] = CreateTranspose; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeInt64)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_Unique][REGIST_DT(kNumberTypeInt32)] = CreateUnique; + creators[PrimType_Unique][REGIST_DT(kNumberTypeFloat32)] = CreateUnique; +} + +void init_vs_kernels(KernelCreator **creators) { + InitVSKernelA(creators); + InitVSKernelI(creators); + InitVSKernelR(creators); + InitVSKernelF16(creators); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..1a8c9d53850df5f1be1e9ded29dfedd0545dde0f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h @@ -0,0 +1,20 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_INIT_VS_KERNELS_H_ +#define NNACL_KERNEL_INIT_VS_KERNELS_H_ +#include "nnacl_c/kernel.h" +void init_vs_kernels(KernelCreator **creators); +#endif  // NNACL_KERNEL_INIT_VS_KERNELS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..de70e04a2b5a5a719d21f90836d333d0d0e9259a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c @@ -0,0 +1,130 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/layer_norm.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/fp32/layer_norm_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/layer_norm_fp16.h" +#endif + +int LayerNormRun(void *cdata, int task_id, float l, float r) { + LayerNormStruct *ln = (LayerNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(ln); + if (ln->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + return LayerNormFp16(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_, + &ln->compute_, task_id, ln->base_.thread_nr_); +#endif + } + return LayerNorm(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_, + &ln->compute_, task_id, ln->base_.thread_nr_); +} + +int LayerNormResize(KernelBase *self) { + LayerNormStruct *layer_norm = (LayerNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm); + LayerNormComputeParam *compute = &layer_norm->compute_; + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + if (compute->begin_norm_axis_ < 0) { + compute->begin_norm_axis_ = compute->begin_norm_axis_ + (int)input->shape_size_; + } + + if (compute->begin_params_axis_ < 0) { + compute->begin_params_axis_ = compute->begin_params_axis_ + (int)input->shape_size_; + } + + compute->norm_outer_size_ = 1; + for (int i = 0; i < compute->begin_norm_axis_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_outer_size_, input->shape_[i], NNACL_ERR); + compute->norm_outer_size_ *= input->shape_[i]; + } + + compute->norm_inner_size_ = 1; + for (size_t i = compute->begin_norm_axis_; i < input->shape_size_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_inner_size_, input->shape_[i], NNACL_ERR); + compute->norm_inner_size_ *= input->shape_[i]; + } + + compute->params_outer_size_ = 1; + for (int i = 0; i < compute->begin_params_axis_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_outer_size_, input->shape_[i], NNACL_ERR); + compute->params_outer_size_ *= input->shape_[i]; + } + + compute->params_inner_size_ = 1; + for (size_t i = compute->begin_params_axis_; i < input->shape_size_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_inner_size_, input->shape_[i], NNACL_ERR); + compute->params_inner_size_ *= input->shape_[i]; + } + + int 
out_num = NNACLGetElementNum(self->out_[OUTPUT_INDEX]); + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_LayerNormFusion), compute->norm_inner_size_, + compute->norm_inner_size_, out_num, self->thread_nr_); + self->thread_nr_ = NNACL_MIN(compute->norm_outer_size_, self->thread_nr_); + return NNACL_OK; +} + +int LayerNormCompute(KernelBase *self) { + LayerNormStruct *layer_norm = (LayerNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm); + + layer_norm->src_data_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->src_data_); + layer_norm->gamma_data_ = self->in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->gamma_data_); + layer_norm->beta_data_ = self->in_[THIRD_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->beta_data_); + layer_norm->dst_data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->dst_data_); + + if (layer_norm->base_.out_size_ == THREE_TENSOR) { + layer_norm->mean_data_ = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->mean_data_); + layer_norm->var_data_ = self->out_[Index2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->var_data_); + } else if (layer_norm->base_.out_size_ != ONE_TENSOR) { + return NNACL_LAYER_NORM_OUTPUT_NUM_INVALID; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, LayerNormRun, self, self->thread_nr_); +} + +KernelBase *CreateLayerNorm(OpParameter *param, int data_type) { + LayerNormStruct *layer_norm = (LayerNormStruct *)malloc(sizeof(LayerNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(layer_norm); + memset(layer_norm, 0, sizeof(LayerNormStruct)); + layer_norm->data_type_ = data_type; + + LayerNormParameter *layer_norm_param = (LayerNormParameter *)param; + layer_norm->compute_.epsilon_ = layer_norm_param->epsilon_; + layer_norm->compute_.elementwise_affine_ = layer_norm_param->elementwise_affine_; + layer_norm->compute_.begin_norm_axis_ = layer_norm_param->begin_norm_axis_; + layer_norm->compute_.begin_params_axis_ = layer_norm_param->begin_params_axis_; + + layer_norm->base_.Prepare = DefaultPrepare3In1Out; + layer_norm->base_.Release = DefaultRelease; + layer_norm->base_.Resize = LayerNormResize; + layer_norm->base_.Compute = LayerNormCompute; + return (KernelBase *)layer_norm; +} + +REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat16, CreateLayerNorm) +REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat32, CreateLayerNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..5b561a65ed6ecd7a3d2957a9eed4b7f733309119 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_LAYER_NORM_H_ +#define NNACL_KERNEL_LAYER_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct LayerNormComputeParam { + float epsilon_; + bool elementwise_affine_; + int begin_norm_axis_; + int begin_params_axis_; + int norm_inner_size_; + int norm_outer_size_; + int params_inner_size_; + int params_outer_size_; +} LayerNormComputeParam; + +typedef struct LayerNormStruct { + KernelBase base_; + LayerNormComputeParam compute_; + int data_type_; + void *src_data_; + void *dst_data_; + void *gamma_data_; + void *beta_data_; + void *mean_data_; + void *var_data_; +} LayerNormStruct; + +KernelBase *CreateLayerNorm(OpParameter *param, int data_type); + +#endif  // NNACL_KERNEL_LAYER_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..a16b31c5740ca2a30dd1cb73e2c714a5a64bf161 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c @@ -0,0 +1,77 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/local_response_norm.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/local_response_norm_fp32.h" +#include "nnacl_c/tensor_c_utils.h" + +int LocalResponseNormRun(void *cdata, int task_id, float l, float r) { + LocalResponseNormStruct *lrn = (LocalResponseNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(lrn); + LocalResponseNormParameter *param = (LocalResponseNormParameter *)lrn->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input = lrn->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = lrn->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_4D, NNACL_LOCAL_RESPONSE_NORM_SHAPE_INVALID); + NNACL_CHECK_FALSE(param->depth_radius_ <= 0, NNACL_LOCAL_RESPONSE_NORM_DEPTH_RADIUS_INVALID); + + float *input_ptr = (float *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + float *output_ptr = (float *)output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + int batch = NNACLGetBatch(input); + int height = NNACLGetHeight(input); + int width = NNACLGetWidth(input); + int channel = NNACLGetChannel(input); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(batch, width, NNACL_ERR); + int size_bw = batch * width; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(size_bw, height, NNACL_ERR); + int outer_size = size_bw * height; + int stride = UP_DIV(outer_size, lrn->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, NNACL_ERR); + int start = stride * task_id; + int count = MSMIN(stride, outer_size - start); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(start, channel, NNACL_ERR); + input_ptr += start * channel; + output_ptr += start * channel; + + return LocalResponseNorm(input_ptr, count, channel, output_ptr, param); +} + +int LrnCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, LocalResponseNormRun, self, self->thread_nr_); +} + +KernelBase *CreateLocalResponseNorm(OpParameter *param, int data_type) { + LocalResponseNormStruct *lrn = (LocalResponseNormStruct *)malloc(sizeof(LocalResponseNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(lrn); + memset(lrn, 0, sizeof(LocalResponseNormStruct)); + + lrn->base_.Prepare = DefaultPrepare1In1Out; + lrn->base_.Release = DefaultRelease; + lrn->base_.Resize = DefaultResize; + lrn->base_.Compute = LrnCompute; + return (KernelBase *)lrn; +} + +REG_KERNEL_CREATOR(PrimType_LRN, kNumberTypeFloat32, CreateLocalResponseNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..0b3ebf7306c1c1db98a20e78b44f5d80294c0699 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_ +#define NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct LocalResponseNormStruct { + KernelBase base_; +} LocalResponseNormStruct; + +KernelBase *CreateLocalResponseNorm(OpParameter *param, int data_type); + +#endif  // NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c new file mode 100644 index 0000000000000000000000000000000000000000..3c311766c3382671a947edfd61c39c4035c2605c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c @@ -0,0 +1,120 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/log_softmax.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/log_softmax_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/log_softmax_fp16.h" +#endif + +int LogSoftmaxLastAxisRun(void *cdata, int task_id, float l, float r) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + TensorC *in = log_softmax->softmax_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *input_ptr = in->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = log_softmax->softmax_.base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + void *tmp_ptr = log_softmax->softmax_.sum_data_; + NNACL_CHECK_NULL_RETURN_ERR(tmp_ptr); + + int unit = UP_DIV(log_softmax->softmax_.out_plane_size_, log_softmax->softmax_.base_.thread_nr_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, log_softmax->softmax_.out_plane_size_); + int channel = in->shape_[log_softmax->softmax_.axis_]; + int offset = begin * channel; + +#ifdef ENABLE_FP16 + if (log_softmax->softmax_.data_type_ == kNumberTypeFloat16) { + LogSoftmaxLastAxisFp16((const float16_t *)input_ptr + offset, (float16_t *)output_ptr + offset, + (float16_t *)tmp_ptr + offset, end - begin, channel); + return NNACL_OK; + } +#endif + LogSoftmaxLastAxis((const float *)input_ptr + offset, (float *)output_ptr + offset, (float *)tmp_ptr + offset, + end - begin, channel); + return NNACL_OK; +} + +int LogSoftmaxResize(struct KernelBase *self) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + int ret = InitSoftmaxParam(&log_softmax->softmax_); + if (ret != NNACL_OK) { + return ret; + } + + if (log_softmax->softmax_.in_plane_size_ == 1 && log_softmax->softmax_.sum_data_ == NULL) { + TensorC *in = log_softmax->softmax_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + SoftmaxStruct *softmax = &log_softmax->softmax_; + + int sum_data_size = softmax->in_plane_size_ * softmax->out_plane_size_ * in->shape_[softmax->axis_]; + softmax->sum_data_ = 
self->env_->Alloc(self->env_->allocator_, sum_data_size * DataTypeCSize(softmax->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(softmax->sum_data_); + } + return NNACL_OK; +} + +int LogSoftmaxCompute(struct KernelBase *self) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + if (log_softmax->softmax_.in_plane_size_ == 1) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, LogSoftmaxLastAxisRun, self, self->thread_nr_); + } + + TensorC *in = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *input_ptr = in->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + NNACL_CHECK_NULL_RETURN_ERR(log_softmax->softmax_.sum_data_); + +#ifdef ENABLE_FP16 + if (log_softmax->softmax_.data_type_ == kNumberTypeFloat16) { + LogSoftmaxFp16((const float16_t *)input_ptr, (float16_t *)output_ptr, (float16_t *)log_softmax->softmax_.sum_data_, + in->shape_, in->shape_size_, log_softmax->softmax_.axis_); + return NNACL_OK; + } +#endif + LogSoftmax((const float *)input_ptr, (float *)output_ptr, (float *)log_softmax->softmax_.sum_data_, in->shape_, + in->shape_size_, log_softmax->softmax_.axis_); + return NNACL_OK; +} + +KernelBase *CreateLogSoftmax(OpParameter *param, int data_type) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)malloc(sizeof(LogSoftmaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(log_softmax); + memset(log_softmax, 0, sizeof(LogSoftmaxStruct)); + + log_softmax->softmax_.sum_data_ = NULL; + log_softmax->softmax_.data_type_ = data_type; + log_softmax->softmax_.base_.Prepare = DefaultPrepare1In1Out; + log_softmax->softmax_.base_.Release = SoftmaxRelease; + log_softmax->softmax_.base_.Resize = LogSoftmaxResize; + log_softmax->softmax_.base_.Compute = LogSoftmaxCompute; + return (KernelBase *)log_softmax; +} + +REG_KERNEL_CREATOR(PrimType_LogSoftmax, kNumberTypeFloat32, CreateLogSoftmax) +REG_KERNEL_CREATOR(PrimType_LogSoftmax, kNumberTypeFloat16, CreateLogSoftmax) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..65abf3c3f4590449a6c1795caf716d5d89c84506 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_LOG_SOFTMAX_H_ +#define NNACL_KERNEL_LOG_SOFTMAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/softmax.h" + +typedef struct LogSoftmaxStruct { + SoftmaxStruct softmax_; +} LogSoftmaxStruct; + +KernelBase *CreateLogSoftmax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_LOG_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c new file mode 100644 index 0000000000000000000000000000000000000000..8644c41a6d25f094967ccf9b80bfb1e062dd5633 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c @@ -0,0 +1,176 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/matmul.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_create.h" + +void MatmulInitShapeA(MatmulStruct *matmul) { + int *a_shape = matmul->base_.in_[kInputIndex]->shape_; + size_t a_shape_size = matmul->base_.in_[kInputIndex]->shape_size_; + int batch = 1; + NNACL_CHECK_TRUE_RET_VOID(a_shape_size >= C2NUM); + for (size_t i = 0; i < a_shape_size - C2NUM; ++i) { + batch *= a_shape[i]; + } + matmul->a_batch_ = batch; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.row_ = param->a_transpose_ ? a_shape[a_shape_size - 1] : a_shape[a_shape_size - C2NUM]; + matmul->compute_.deep_ = param->a_transpose_ ? a_shape[a_shape_size - C2NUM] : a_shape[a_shape_size - 1]; +} + +void MatmulInitShapeB(MatmulStruct *matmul) { + int *b_shape = matmul->base_.in_[kWeightIndex]->shape_; + size_t b_shape_size = matmul->base_.in_[kWeightIndex]->shape_size_; + int batch = 1; + NNACL_CHECK_TRUE_RET_VOID(b_shape_size >= C2NUM); + for (size_t i = 0; i < b_shape_size - C2NUM; ++i) { + batch *= b_shape[i]; + } + matmul->b_batch_ = batch; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.col_ = param->b_transpose_ ? b_shape[b_shape_size - C2NUM] : b_shape[b_shape_size - 1]; + matmul->compute_.deep_ = param->b_transpose_ ? 
b_shape[b_shape_size - 1] : b_shape[b_shape_size - C2NUM]; +} + +int MatmulInitBroadcastParams(MatmulStruct *matmul) { + TensorC *a = matmul->base_.in_[FIRST_INPUT]; + TensorC *b = matmul->base_.in_[SECOND_INPUT]; + + int max_dim_size = (int)NNACL_MAX(a->shape_size_, b->shape_size_); + max_dim_size = NNACL_MAX(max_dim_size, COMM_SHAPE_SIZE); + + int a_shape[MAX_SHAPE_SIZE] = {0}; + int index = max_dim_size - 1; + for (int i = (int)a->shape_size_ - 1; i >= 0; i--) { + a_shape[index--] = a->shape_[i]; + } + for (; index >= 0;) { + a_shape[index--] = 1; + } + + int b_shape[MAX_SHAPE_SIZE] = {0}; + index = max_dim_size - 1; + for (int i = (int)b->shape_size_ - 1; i >= 0; i--) { + b_shape[index--] = b->shape_[i]; + } + for (; index >= 0;) { + b_shape[index--] = 1; + } + + int batch_sizes[MAX_SHAPE_SIZE] = {0}; + int a_batch_sizes[MAX_SHAPE_SIZE] = {0}; + int b_batch_sizes[MAX_SHAPE_SIZE] = {0}; + for (int i = max_dim_size - Num3; i >= 0; --i) { + if (max_dim_size - Num3 == i) { + batch_sizes[i] = NNACL_MAX(a_shape[i], b_shape[i]); + a_batch_sizes[i] = a_shape[i]; + b_batch_sizes[i] = b_shape[i]; + } else { + batch_sizes[i] = batch_sizes[i + 1] * NNACL_MAX(a_shape[i], b_shape[i]); + a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i]; + b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i]; + } + } + + int out_batch = 1; + for (int i = 0; i < max_dim_size - Num2; ++i) { + int max_v = NNACL_MAX(a_shape[i], b_shape[i]); + int min_v = NNACL_MIN(a_shape[i], b_shape[i]) > 0 ? NNACL_MIN(a_shape[i], b_shape[i]) : 1; + out_batch *= max_v; + if ((max_v != min_v) && ((max_v % min_v) != 0)) { + return NNACL_ERR; + } + } + matmul->batch_ = out_batch; + + MatmulBaseFreeBatchOffset(matmul); + int ret = MatmulBaseMallocBatchOffset(matmul); + if (ret != NNACL_OK) { + return ret; + } + + for (int i = 0; i < matmul->batch_; ++i) { + int delta = i; + int a_offset = 0; + int b_offset = 0; + for (int j = 0; j < max_dim_size - Num2; ++j) { + if (j > 0) { + delta = delta % batch_sizes[j]; + } + if (j >= (MAX_SHAPE_SIZE - 1)) { + return NNACL_ERR; + } + if (j < (max_dim_size - Num3)) { + a_offset += + (delta / batch_sizes[j + 1] * a_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1]; + b_offset += + (delta / batch_sizes[j + 1] * b_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1]; + } else { + a_offset += (delta * a_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])); + b_offset += (delta * b_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])); + } + } + matmul->a_offset_[i] = a_offset; + matmul->b_offset_[i] = b_offset; + } + return NNACL_OK; +} + +int MatmulPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR); + + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->a_const_ || matmul->infer_shape_) { + MatmulInitShapeA(matmul); + } + + if (matmul->b_const_ || matmul->infer_shape_) { + MatmulInitShapeB(matmul); + } + + return MatmulBasePrepare(self); +} + +int MatmulResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + MatmulInitShapeA(matmul); + MatmulInitShapeB(matmul); + + int ret = MatmulInitBroadcastParams(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + return MatmulBaseResize(self); +} + +int MatmulRelease(KernelBase *self) { + MatmulBaseFreeBatchOffset((MatmulStruct *)self); + return MatmulBaseRelease(self); +} + +KernelBase *CreateMatmul(OpParameter *param, int data_type) { + KernelBase *kernel = NULL; + if (data_type == kNumberTypeFloat32) { + kernel = 
CreateMatmulKernel(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + kernel->Prepare = MatmulPrepare; + kernel->Resize = MatmulResize; + kernel->Release = MatmulRelease; + } + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_MatMulFusion, kNumberTypeFloat32, CreateMatmul); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..db1d384023c1c4deea4812c876d65554eb3e6508 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_H_ +#define NNACL_KERNEL_MATMUL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmul(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c new file mode 100644 index 0000000000000000000000000000000000000000..43bb65d8d1bd8c3708e7b4cff9260fe9b4ec4258 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/matmul_arm32.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" + +void MatmulARM32InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
RowMajor2Col4MajorParallel : RowMajor2Row4MajorParallel; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C4NUM; + matmul->compute_.col_min_unit_ = C4NUM; +} + +int MatmulARM32ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulARM32ParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulARM32() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulARM32InitGlobalVariable; + matmul->parallel_run_by_batch_ = MatmulARM32ParallelRunByBatch; + matmul->parallel_run_by_oc_ = MatmulARM32ParallelRunByOC; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h new file mode 100644 index 0000000000000000000000000000000000000000..a3bc34ac15fc86f9cd2ca2d4c33c2b69b2d9548f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_ARM32_H_ +#define NNACL_KERNEL_MATMUL_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulARM32(); + +#endif +#endif // NNACL_KERNEL_MATMUL_ARM32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..c61068a756decccc6d52659a61cd0cd18ce68480 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c @@ -0,0 +1,214 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/matmul_arm64.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32_opt.h" + +typedef struct MatrixAPack { + int64_t points_[MAX_THREAD_NUM]; + int64_t unit_num_; + int thread_; + int deep_; + int row_; + int col_; + MatrixInfo *matrix_a_; + float *src_ptr_; + bool a_transpose_; +} MatrixAPack; + +int MatmulARM64PackMatrixAImplOptPack(void *cdata, int task_id, float l, float r) { + MatrixAPack *pack = (MatrixAPack *)cdata; + int64_t start = pack->points_[task_id]; + int64_t end = pack->unit_num_; + if (task_id < pack->thread_ - 1) { + end = pack->points_[task_id + 1]; + } + + if (pack->a_transpose_) { + RowMajor2Row12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->deep_, pack->row_, start, end); + } else { + RowMajor2Col12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->row_, pack->deep_, start, end); + } + return NNACL_OK; +} + +int MatmulARM64PackMatrixAImplOpt(MatmulStruct *matmul) { + int64_t kPackAMinUnitNum = 1 << 13; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + float *src_ptr = matmul->matrix_a_.origin_ptr_ != NULL ? matmul->matrix_a_.origin_ptr_ + : (float *)(matmul->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR); + + MatrixAPack pack; + pack.src_ptr_ = src_ptr; + pack.matrix_a_ = &matmul->matrix_a_; + pack.deep_ = matmul->compute_.deep_; + pack.col_ = matmul->compute_.col_; + pack.row_ = matmul->compute_.row_; + pack.a_transpose_ = param->a_transpose_; + pack.unit_num_ = 0; + pack.unit_num_ = matmul->a_batch_ * UP_DIV(matmul->compute_.row_, C12NUM) * matmul->compute_.deep_; + pack.thread_ = MSMIN(matmul->base_.thread_nr_, UP_DIV(pack.unit_num_, kPackAMinUnitNum)); + if (pack.thread_ < 1) { + pack.thread_ = 1; + } + int64_t block_size = pack.unit_num_ / pack.thread_; + int64_t remain_size = pack.unit_num_ - block_size * pack.thread_; + int64_t start = 0; + size_t count = 0; + while (start < pack.unit_num_) { + pack.points_[count++] = start; + start += block_size; + if (remain_size > 0) { + ++start; + --remain_size; + } + } + pack.thread_ = count; + + if (pack.thread_ == 1) { + return MatmulARM64PackMatrixAImplOptPack(&pack, 0, 0, 1); + } + return matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulARM64PackMatrixAImplOptPack, &pack, + pack.thread_); +} + +bool MatmulARM64CheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->batch_ >= matmul->base_.thread_nr_ || matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C4NUM; + return true; + } + return false; +} +void MatmulARM64InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->pack_opt_ = true; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; +} + +int MatmulARM64ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulARM64ParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int start_row = matmul->split_points_[task_id]; + int end_row = matmul->compute_.row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + GemmIsNotPackByRow(matmul->matrix_a_.pack_ptr_, matmul->matrix_b_.pack_ptr_, matmul->output_data_, + matmul->matrix_c_.pack_ptr_, start_row, end_row, matmul->compute_.deep_, param->act_type_); + return NNACL_OK; +} + +int MatmulARM64ParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulARM64() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kMatmulFp32Arm64Cpu; + matmul->check_thread_cutting_by_row_ = MatmulARM64CheckThreadCuttingByRow; + matmul->init_global_varibale_ = MatmulARM64InitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulARM64ParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulARM64ParallelRunByRow; + matmul->parallel_run_by_batch_ = MatmulARM64ParallelRunByBatch; + matmul->pack_matrix_a_impl_opt_ = MatmulARM64PackMatrixAImplOpt; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..35c938a745b3ebf92357f5b37c3ace7f689395ed --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_ARM64_H_ +#define NNACL_KERNEL_MATMUL_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulARM64(); + +#endif +#endif // NNACL_KERNEL_MATMUL_ARM64_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..9ccae7ec10a8dad7f86c8cc9af25f0395a9077e7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c @@ -0,0 +1,169 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/matmul_avx.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +void MatmulAVXInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.row_tile_ = C1NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C32NUM; + matmul->out_need_aligned_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_.need_pack_ = param->a_transpose_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col32MajorParallel : RowMajor2Row32MajorParallel; +} + +int MatmulAVXParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (matmul->compute_.row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + ActType act = param->act_type_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_); + } else if (func_flag == C1NUM) { + MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulAVXParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + + int start_row = matmul->split_points_[task_id]; + int end_row = compute->row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + const float *input = matmul->matrix_a_.pack_ptr_ + start_row * compute->deep_; + float *output = matmul->output_data_ + start_row * compute->col_align_; + if (compute->col_ == 1) { + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, compute->deep_, + param->act_type_); + } else { + MatMulAvxFp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + compute->deep_, compute->col_align_, compute->col_align_, row_num); + } + return NNACL_OK; +} + +int MatmulAVXParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || 
task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulAvxFp32(a, b, c, bias, param->act_type_, compute->deep_, compute_oc, compute->col_align_, compute->row_); + } else if (func_flag == C1NUM) { + MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +bool MatmulAVXCheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) { + return false; + } + if (matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C4NUM; + return true; + } + if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + return false; + } + matmul->compute_.row_min_unit_ = C3NUM; + if (matmul->compute_.col_step_ < C16NUM) { + matmul->compute_.row_min_unit_ = C8NUM; + } else if (matmul->compute_.col_step_ < C24NUM) { + matmul->compute_.row_min_unit_ = C6NUM; + } else if (matmul->compute_.col_step_ < C32NUM) { + matmul->compute_.row_min_unit_ = C4NUM; + } + return MSMIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) > + MSMIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_); +} + +KernelBase *CreateMatmulAVX() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulAVXInitGlobalVariable; + matmul->parallel_run_by_batch_ = MatmulAVXParallelRunByBatch; + matmul->parallel_run_by_row_ = MatmulAVXParallelRunByRow; + matmul->parallel_run_by_oc_ = MatmulAVXParallelRunByOC; + matmul->check_thread_cutting_by_row_ = MatmulAVXCheckThreadCuttingByRow; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..bb722473a1987b7378a32bd7c26356d18049238c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_AVX_H_ +#define NNACL_KERNEL_MATMUL_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulAVX(); + +#endif +#endif // NNACL_KERNEL_MATMUL_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c new file mode 100644 index 0000000000000000000000000000000000000000..75e80eaf3aa7c924cfb8d6559e86e3aa5dd79aa5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c @@ -0,0 +1,708 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX512 +#include "nnacl_c/kernel/matmul_avx512.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" + +#define MIN_CALC_COST 24576 /* 1 x 6 x 64x 64 */ + +void MatmulAVX512BatchRowThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + // RowCut + int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_; + + matmul->row_split_points_size_ = 0; + int row_split_point = 0; + while (row_split_point < matmul->compute_.row_) { + matmul->row_split_points_[matmul->row_split_points_size_++] = row_split_point; + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + } + matmul->row_split_points_[matmul->row_split_points_size_] = matmul->compute_.row_; + if (matmul->compute_.batch_stride_ == 0) { + matmul->base_.thread_nr_ = matmul->row_split_points_size_; + } +} + +void MatmulAVX512BatchColThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + // ColCut + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + int thread_num_tmp = NNACL_MIN(matmul->base_.thread_nr_, total_col_unit); + int block_col_unit = UP_DIV(total_col_unit, thread_num_tmp); + int split_point = 0; + matmul->col_split_points_size_ = 0; + while (split_point < total_col_unit) { + matmul->col_split_points_[matmul->col_split_points_size_++] = split_point * matmul->compute_.col_min_unit_; + split_point += block_col_unit; + } + if 
(matmul->compute_.batch_stride_ == 0) { + matmul->base_.thread_nr_ = matmul->col_split_points_size_; + } +} + +void MatmulAVX512BatchColRowSliceThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + int row_s = 0; + int row_e = matmul->compute_.row_; + int col_s = 0; + int col_e = matmul->compute_.col_; + + // ColCut + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + matmul->compute_.block_col_unit_ = DOWN_DIV(total_col_unit, matmul->base_.thread_nr_); + matmul->col_split_points_size_ = 0; + matmul->col_split_points_[matmul->col_split_points_size_++] = 0; + if (matmul->compute_.block_col_unit_ > 0) { + int col_split_point = 0; + for (int i = 0; i < matmul->base_.thread_nr_; i++) { + MatmulSlice matmul_slice; + matmul_slice.row_s_ = row_s; + matmul_slice.row_e_ = row_e; + matmul_slice.col_s_ = col_split_point * matmul->compute_.col_min_unit_; + col_split_point += matmul->compute_.block_col_unit_; + col_s = NNACL_MIN(col_split_point * matmul->compute_.col_min_unit_, matmul->compute_.col_step_); + matmul_slice.col_e_ = col_s; + matmul->matmul_slice_set_[i][matmul->matmul_slice_count_[i]++] = matmul_slice; + } + } + if (col_e - col_s <= 0) { + return; + } + + // RowColCut + int row_thread = 0; + int less_col_align = UP_ROUND(col_e - col_s, C16NUM); + bool use_colrowcut_flag = ((less_col_align / C64NUM) * C64NUM) == less_col_align; + bool use_rowcut_flag = matmul->compute_.row_ >= C6NUM * matmul->base_.thread_nr_ || col_e - col_s <= C64NUM; + if (use_rowcut_flag && !use_colrowcut_flag) { + int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_; + int row_split_point = 0; + + for (row_thread = 0; row_thread < matmul->base_.thread_nr_ && row_split_point < matmul->compute_.row_; + row_thread++) { + MatmulSlice matmul_slice; + matmul_slice.row_s_ = row_split_point; + + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + + matmul_slice.row_e_ = row_split_point; + matmul_slice.col_s_ = col_s; + matmul_slice.col_e_ = col_e; + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + } + } else { + int col_num = UP_DIV(col_e - col_s, C64NUM); + int row_num = NNACL_MIN(UP_DIV(matmul->base_.thread_nr_, col_num), (row_e - row_s)); + int tile_remaining = MSMAX(col_num * row_num - matmul->base_.thread_nr_, 0); + + NNACL_CHECK_ZERO_RETURN(row_num); + int row_step = (row_e - row_s) / row_num; + int row_remaining_tmp = (row_e - row_s) - row_step * row_num; + + int row_step_cut2 = (row_num == 1) ? 
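+ /* row_step_cut2 spreads the same row range over one fewer slice; the (row_num == 1) guard avoids a division by zero. It applies to the trailing column tiles whenever col_num * row_num overshoots thread_nr_ (tile_remaining > 0), keeping the produced slice count in line with the thread count. */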
row_step : (row_e - row_s) / (row_num - 1); + int row_remaining_cut2_tmp = (row_e - row_s) - row_step_cut2 * (row_num - 1); + + MatmulSlice matmul_slice; + for (int c = 0; c < col_num; c++) { + matmul_slice.col_s_ = col_s + c * C64NUM; + matmul_slice.col_e_ = NNACL_MIN(col_s + (c + 1) * C64NUM, matmul->compute_.col_); + int row_split_point = 0; + int row_remaining = row_remaining_tmp; + int row_remaining_cut2 = row_remaining_cut2_tmp; + if (c < col_num - tile_remaining) { + for (int r = 0; r < row_num; r++) { + matmul_slice.row_s_ = row_split_point; + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_); + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + row_thread++; + } + } else { + for (int r = 0; r < row_num - 1; r++) { + matmul_slice.row_s_ = row_split_point; + row_split_point += row_step_cut2; + if (row_remaining_cut2 > 0) { + ++row_split_point; + --row_remaining_cut2; + } + matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_); + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + row_thread++; + } + } + } + } + if ((matmul->compute_.batch_stride_ == 0) && (matmul->compute_.block_col_unit_ == 0)) { + matmul->base_.thread_nr_ = row_thread; + } +} + +void MatmulAVX512GetThreadCuttingPolicy(MatmulStruct *matmul) { + size_t total_cost = (size_t)(matmul->batch_) * (size_t)(matmul->compute_.row_) * (size_t)(matmul->compute_.col_) * + (size_t)(matmul->compute_.deep_); + + // Thread Update + matmul->base_.thread_nr_ = MSMAX(NNACL_MIN((int)(total_cost / MIN_CALC_COST), matmul->base_.thread_nr_), C1NUM); + + if (matmul->compute_.deep_ < C128NUM) { + return MatmulBaseGetThreadCuttingPolicy(matmul); + } + + for (int i = 0; i < SPLIT_COUNT; i++) { + matmul->matmul_slice_count_[i] = 0; + } + if (matmul->compute_.col_ == 1 && !matmul->a_const_) { + MatmulAVX512BatchRowThreadCut(matmul); + if (matmul->compute_.deep_ == 1) { + matmul->gemm_not_pack_fun_ = GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize; + } + matmul->parallel_run_ = matmul->parallel_run_by_gepdot_; + } else if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + MatmulAVX512BatchColThreadCut(matmul); + if (matmul->compute_.deep_ == 1) { + matmul->parallel_run_ = matmul->parallel_run_by_row1_deep1_gepdot_; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + matmul->gemm_not_pack_fun_ = Row1Deep1GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = Row1Deep1NoBiasGemmIsNotPack; + } + return; + } + matmul->parallel_run_ = matmul->parallel_run_by_gepm_; + } else { + MatmulAVX512BatchColRowSliceThreadCut(matmul); + matmul->parallel_run_ = matmul->parallel_run_by_batch_col_row_gemm_; + } + return; +} + +bool MatmulAVX512CheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) { + return false; + } + if (matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C8NUM; + return true; + } + if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + return false; + } + matmul->compute_.row_min_unit_ = C6NUM; + if (matmul->compute_.col_step_ < C48NUM) { + matmul->compute_.row_min_unit_ = C12NUM; + } else if (matmul->compute_.col_step_ < C64NUM) { + matmul->compute_.row_min_unit_ = C8NUM; + } + return 
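+ /* Cut by row only if row cutting can keep more threads busy than column cutting would: compare the usable worker count under each scheme, capped at thread_nr_. E.g. (illustrative values) row_num_ = 48 with row_min_unit_ = 6 supports up to 8 workers, while col_step_ = 64 with col_min_unit_ = 64 supports only one. */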
NNACL_MIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) > + NNACL_MIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_); +} +void MatmulAVX512InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col64MajorParallel : RowMajor2Row64MajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_; + matmul->matrix_b_.need_pack_ = true; + matmul->compute_.row_tile_ = C1NUM; + matmul->compute_.col_tile_ = C16NUM; + matmul->compute_.col_min_unit_ = C64NUM; + + if (matmul->compute_.row_ == 1) { + if (!matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + matmul->out_need_aligned_ = true; + } + } else if (matmul->compute_.col_ == 1) { + matmul->out_need_aligned_ = true; + } else { + matmul->out_need_aligned_ = false; + } + + if (matmul->compute_.deep_ >= C128NUM) { + matmul->out_need_aligned_ = false; + } +} +int MatmulAVX512InitParameter(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + if (compute->deep_ < C128NUM) { + return MatmulBaseInitParameter(matmul); + } + + matmul->init_global_varibale_(matmul); + if (compute->col_ == 1 && !matmul->a_const_) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1; + matmul->matrix_b_.need_pack_ = false; + matmul->pack_opt_ = false; + } else if (compute->row_ == 1 && !matmul->b_const_ && compute->col_ <= C128NUM) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
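+ /* GEPM-style path (row_ == 1 with a mutable, narrow B): both tiles drop to 1, A is consumed as-is, and B is repacked (a plain transpose copy) only when it arrives transposed, skipping the 64-wide column packing the general path uses. */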
RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = false; + matmul->matrix_b_.need_pack_ = param->b_transpose_; + matmul->pack_opt_ = false; + } + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR); + int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_; + int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_; + if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) || + (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) { + return NNACL_ERR; + } + matmul->matrix_a_.pack_size_ = a_pack_size; + matmul->matrix_b_.pack_size_ = b_pack_size; + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0)); + compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR); + compute->row_num_ = matmul->a_batch_ * compute->row_; + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int start_row = matmul->split_points_[task_id]; + int end_row = matmul->compute_.row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + const float *input = matmul->matrix_a_.pack_ptr_ + start_row * matmul->compute_.deep_; + float *output = matmul->output_data_ + start_row * matmul->compute_.col_step_; + if (matmul->compute_.col_ == 1) { + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, matmul->compute_.deep_, + param->act_type_); + } else { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + matmul->compute_.deep_, matmul->compute_.col_align_, matmul->compute_.col_align_, row_num); + } else { + MatMulMaskAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + matmul->compute_.deep_, matmul->compute_.col_, matmul->compute_.col_, row_num); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if 
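+ /* Threads whose split point already reaches col_step_ receive an empty column window and return immediately. */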
(compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (matmul->compute_.row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_, compute->row_); + } else { + MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_, compute->row_); + } + } else if (func_flag == C1NUM) { + if (matmul->out_need_aligned_) { + MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_); + } else { + MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_); + } + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
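+ /* Bias may be absent (pack_ptr_ == NULL). func_flag picks the kernel: 0 = tiled GEMM, C1NUM = GEMV over packed B, C2NUM = GEMV over unpacked B (row_ == 1, mutable B, col_ <= C128NUM). */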
NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_); + } else { + MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_, compute->row_); + } + } else if (func_flag == C1NUM) { + if (matmul->out_need_aligned_) { + MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_); + } else { + MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_); + } + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByGEPM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matrix_col; + if (task_id < (col_split_points_size - 1)) { + end_oc = matmul->col_split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
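+ /* Batches left over after the even batch cut (indices thread_nr_ * batch_stride_ .. batch_ - 1) are handled below, shared across threads by output-column window instead. */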
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col); + } + } + return NNACL_OK; +} +int MatmulAVX512ParallelRunByGEMM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_row = matmul->compute_.row_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = start_batch + matmul->compute_.batch_stride_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matmul->col_split_points_[task_id + 1]; + int compute_oc = end_oc - start_oc; + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + if (compute_oc > 0) { + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, matrix_row); + } + } + } + + // by RowCut + int start_oc = matmul->col_split_points_[col_split_points_size]; + int end_oc = matrix_col; + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int row_split_points_size = matmul->row_split_points_size_; + if (task_id >= row_split_points_size) { + return NNACL_OK; + } + int start_row = matmul->row_split_points_[task_id]; + int end_row = matmul->row_split_points_[task_id + 1]; + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
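+ /* Final leftover region: columns beyond the last recorded split point are swept with a per-thread row window over the remaining batches. */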
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num); + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByGEPDOT(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + + // by BatchCut + int start_batch = task_id * compute->batch_stride_; + int end_batch = start_batch + compute->batch_stride_; + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + int a_stride = compute->row_ * compute->deep_; + int b_stride = compute->deep_ * compute->col_; + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, compute->row_, compute->deep_, param->act_type_); + } + + // by RowCut + int split_points_size = matmul->row_split_points_size_; + if (task_id >= split_points_size) { + return NNACL_OK; + } + for (int index = matmul->base_.thread_nr_ * compute->batch_stride_; index < matmul->batch_; ++index) { + int start_row = matmul->row_split_points_[task_id]; + int end_row = matmul->row_split_points_[task_id + 1]; + int row_num = end_row - start_row; + if (row_num <= 0) { + continue; + } + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride + start_row * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_ + start_row * compute->col_step_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, row_num, compute->deep_, param->act_type_); + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByRow1Deep1GEPDOT(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
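+ /* With row_ == 1 and deep_ == 1 each output is essentially a[0] * b[j] plus the bias, so gemm_not_pack_fun_ (Row1Deep1GemmIsNotPack or its no-bias variant) just sweeps col_step_ outputs. */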
NULL : matmul->matrix_c_.pack_ptr_; + matmul->gemm_not_pack_fun_(a, b, c, bias, matrix_col, matrix_deep, param->act_type_); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matrix_col; + if (task_id < (col_split_points_size - 1)) { + end_oc = matmul->col_split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + matmul->gemm_not_pack_fun_(a, b, c, bias, compute_oc, matrix_deep, param->act_type_); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByBatchColRowGEMM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_row = matmul->compute_.row_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = start_batch + matmul->compute_.batch_stride_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row); + } + + MatmulSlice *matmul_slices = matmul->matmul_slice_set_[task_id]; + int slice_count = matmul->matmul_slice_count_[task_id]; + for (int s = 0; s < slice_count; s++) { + MatmulSlice matmul_slice = matmul_slices[s]; + + int start_oc = matmul_slice.col_s_; + int end_oc = matmul_slice.col_e_; + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int start_row = matmul_slice.row_s_; + int end_row = matmul_slice.row_e_; + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
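+ /* Each thread then walks its precomputed (row, col) slices from matmul_slice_set_, covering the batches that the even batch cut left behind. */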
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulAVX512() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->check_thread_cutting_by_row_ = MatmulAVX512CheckThreadCuttingByRow; + matmul->get_thread_cutting_policy_ = MatmulAVX512GetThreadCuttingPolicy; + matmul->init_parameter_ = MatmulAVX512InitParameter; + matmul->init_global_varibale_ = MatmulAVX512InitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulAVX512ParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulAVX512ParallelRunByRow; + matmul->parallel_run_by_batch_ = MatmulAVX512ParallelRunByBatch; + matmul->parallel_run_by_gemm_ = MatmulAVX512ParallelRunByGEMM; + matmul->parallel_run_by_gepm_ = MatmulAVX512ParallelRunByGEPM; + matmul->parallel_run_by_gepdot_ = MatmulAVX512ParallelRunByGEPDOT; + matmul->parallel_run_by_batch_col_row_gemm_ = MatmulAVX512ParallelRunByBatchColRowGEMM; + matmul->parallel_run_by_row1_deep1_gepdot_ = MatmulAVX512ParallelRunByRow1Deep1GEPDOT; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h new file mode 100644 index 0000000000000000000000000000000000000000..4233286cedfc30436c84df996975b5119919332b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_AVX512_H_ +#define NNACL_KERNEL_MATMUL_AVX512_H_ +#ifdef ENABLE_AVX512 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulAVX512(); + +#endif +#endif // NNACL_KERNEL_MATMUL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c new file mode 100644 index 0000000000000000000000000000000000000000..35917710ab657b1b57a7e99d1d75cbdc32dc13ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c @@ -0,0 +1,676 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/op_base.h" + +#define kNumDeepThreshold 512 + +int MatmulFp32Run(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + MatmulStruct *matmul = (MatmulStruct *)cdata; + return matmul->parallel_run_(matmul, task_id); +} + +void MatmulBaseFreeBatchOffset(MatmulStruct *matmul) { + if (matmul->a_offset_ != NULL) { + free(matmul->a_offset_); + matmul->a_offset_ = NULL; + } + if (matmul->b_offset_ != NULL) { + free(matmul->b_offset_); + matmul->b_offset_ = NULL; + } +} + +int MatmulBaseMallocBatchOffset(MatmulStruct *matmul) { + matmul->a_offset_ = malloc(matmul->batch_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->a_offset_); + memset(matmul->a_offset_, 0, matmul->batch_ * sizeof(int)); + + matmul->b_offset_ = malloc(matmul->batch_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->b_offset_); + memset(matmul->b_offset_, 0, matmul->batch_ * sizeof(int)); + return NNACL_OK; +} + +int MatmulBasePackMatrixBParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + int start = task_id * compute->pack_b_stride_; + if (param->b_transpose_) { + int end = NNACL_MIN(matmul->compute_.col_, start + compute->pack_b_stride_); + matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->col_, compute->deep_, start, end); + } else { + int end = NNACL_MIN(matmul->compute_.deep_, start + compute->pack_b_stride_); + matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->deep_, compute->col_, start, end); + } + return NNACL_OK; +} + +int MatmulFp32PackMatrixBRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + MatmulStruct *matmul = (MatmulStruct *)cdata; + return MatmulBasePackMatrixBParallelRunByBatch(matmul, task_id); +} + +bool MatmulBaseCheckRowOptimalConditions(MatmulStruct *matmul) { + return matmul->compute_.row_ == 1 && + !(matmul->support_mul_batch_cut_by_row_ && (matmul->a_batch_ > 1 && matmul->b_batch_ == 1)); +} + +int MatmulBaseInitParameter(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + matmul->init_global_varibale_(matmul); + if (MatmulBaseCheckRowOptimalConditions(matmul)) { + compute->row_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = false; + matmul->pack_opt_ = false; + if (!matmul->b_const_ && compute->col_ <= C128NUM) { + compute->col_tile_ = 1; + matmul->out_need_aligned_ = false; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
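+ /* Row-optimal vector path: a 1-row A with a mutable, narrow B (col_ <= C128NUM) skips packing and runs GEMV directly, transposing B only when it arrives transposed. */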
RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_.need_pack_ = param->b_transpose_; + } + } + if (compute->col_ == 1 && !matmul->a_const_) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1; + matmul->matrix_b_.need_pack_ = false; + matmul->pack_opt_ = false; + } + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR); + int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_; + int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_; + if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) || + (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) { + return NNACL_ERR; + } + matmul->matrix_a_.pack_size_ = a_pack_size; + matmul->matrix_b_.pack_size_ = b_pack_size; + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0)); + compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR); + compute->row_num_ = matmul->a_batch_ * compute->row_; + return NNACL_OK; +} + +int MatmulBasePackMatrixAImplOpt(MatmulStruct *matmul) { return NNACL_ERR; } + +int MatmulBasePackMatrixAImpl(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + float *src_ptr = (matmul->matrix_a_.origin_ptr_ != NULL) ? (matmul->matrix_a_.origin_ptr_) + : (float *)(matmul->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_pack_fun_ != NULL, NNACL_ERR); + for (int i = 0; i < matmul->a_batch_; i++) { + const float *src = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.row_; + float *dst = matmul->matrix_a_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.row_align_; + if (param->a_transpose_) { + matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.deep_, matmul->compute_.row_, 0, matmul->compute_.deep_); + } else { + matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.row_, matmul->compute_.deep_, 0, matmul->compute_.row_); + } + } + return NNACL_OK; +} + +int MatmulBasePackMatrixBImpl(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + + float *src_ptr = matmul->matrix_b_.origin_ptr_ != NULL ? 
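+ /* Prefer the backed-up origin buffer when one exists (a const input copied aside earlier); otherwise read the live input tensor. The packing itself is parallelised: pack_b_stride_ splits the source rows across threads via MatmulFp32PackMatrixBRun. */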
matmul->matrix_b_.origin_ptr_ + : (float *)matmul->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_b_.pack_ptr_ != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_b_pack_fun_ != NULL, NNACL_ERR); + + for (int i = 0; i < matmul->b_batch_; i++) { + if (param->b_transpose_) { + matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.col_, matmul->base_.thread_nr_); + } else { + matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.deep_, matmul->base_.thread_nr_); + } + matmul->pack_b_src_ = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.col_; + matmul->pack_b_dst_ = matmul->matrix_b_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.col_align_; + int ret = matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulFp32PackMatrixBRun, matmul, + matmul->base_.thread_nr_); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + return NNACL_OK; +} + +int MatmulBasePackMatrixA(MatmulStruct *matmul) { + if (!matmul->a_const_) { + if (!matmul->matrix_a_.need_pack_) { + matmul->matrix_a_.pack_ptr_ = (float *)matmul->base_.in_[0]->data_; + return NNACL_OK; + } + if (matmul->base_.train_session_) { + matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.workspace_); + } else { + matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, + matmul->matrix_a_.pack_size_ * sizeof(float))); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + } else { + bool is_packed = false; + void *data = NULL; + size_t data_size = (size_t)(matmul->matrix_a_.pack_size_) * sizeof(float); + if (matmul->is_sharing_pack_) { + TensorC *a_matrix = matmul->base_.in_[FIRST_INPUT]; + data = matmul->get_sharing_weight_(matmul->shaing_manager_, a_matrix->data_, data_size, &is_packed); + } else { + data = malloc(data_size); + } + matmul->matrix_a_.pack_ptr_ = (float *)data; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + if (is_packed) { + return NNACL_OK; + } + } + if (matmul->pack_opt_) { + /* valid in arm64 */ + return matmul->pack_matrix_a_impl_opt_(matmul); + } + return matmul->pack_matrix_a_impl_(matmul); +} + +int MatmulBasePackMatrixB(MatmulStruct *matmul) { + if (!matmul->b_const_) { + if (!matmul->matrix_b_.need_pack_) { + matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_; + return NNACL_OK; + } + if (matmul->base_.train_session_) { + matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.workspace_) + matmul->matrix_a_.pack_size_; + } else { + matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, + matmul->matrix_b_.pack_size_ * sizeof(float))); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_); + } else { + if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) { + matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_; + return NNACL_OK; + } + bool is_packed = false; + void *data = NULL; + size_t data_size = (size_t)(matmul->matrix_b_.pack_size_) * sizeof(float); + if (matmul->is_sharing_pack_) { + TensorC *b_matrix = matmul->base_.in_[SECOND_INPUT]; + data = matmul->get_sharing_weight_(matmul->shaing_manager_, b_matrix->data_, data_size, &is_packed); + } else { + data = malloc(data_size); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(data); + matmul->matrix_b_.pack_ptr_ = (float *)data; + if (is_packed) { + return NNACL_OK; + } + } + return matmul->pack_matrix_b_impl_(matmul); +} + +int 
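+ /* Copies a const input into allocator-owned memory and marks it for later release, so repacking reads stable data even if the runtime reuses the tensor's buffer. */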
MatmulBaseBackupConstMatrix(MatmulStruct *matmul, MatrixInfo *matrix_info, int index) { + NNACL_CHECK_TRUE_RET(index < (int)matmul->base_.in_size_, NNACL_ERR); + size_t backup_size = (size_t)NNACLGetElementNum(matmul->base_.in_[index]) * sizeof(float); + NNACL_CHECK_TRUE_RET(backup_size > 0, NNACL_ERR); + matrix_info->origin_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, backup_size)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix_info->origin_ptr_); + void *src_ptr = matmul->base_.in_[index]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_ptr); + (void)memcpy(matrix_info->origin_ptr_, src_ptr, backup_size); + matrix_info->origin_need_free_ = true; + return NNACL_OK; +} + +int MatmulBaseParallelRunByRow(MatmulStruct *matmul, int task_id) { return NNACL_ERR; } + +int MatmulBaseParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, param->act_type_, compute->deep_, compute->row_, compute->col_step_, compute->col_, + OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulBaseParallelRunIsNotPackByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + for (int index = start_batch; index < end_batch; ++index) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * matmul->compute_.row_ * matmul->compute_.deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * matmul->compute_.deep_ * matmul->compute_.col_; + float *c = matmul->output_data_ + index * matmul->compute_.row_ * matmul->compute_.col_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, matmul->compute_.row_, matmul->compute_.deep_, param->act_type_); + } + return NNACL_OK; +} + +void MatmulBaseGetThreadCuttingInfoByRow(MatmulStruct *matmul) { + int row_step = NNACL_MAX(matmul->compute_.row_num_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_num_ - row_step * matmul->base_.thread_nr_; + + int split_point = 0; + int count = 0; + while (split_point < matmul->compute_.row_num_) { + matmul->split_points_[count++] = split_point; + split_point += row_step; + if (row_remaining > 0) { + 
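+ /* Hand one extra row to each leading split until the remainder is used up, so thread shares differ by at most one row. */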
++split_point; + --row_remaining; + } + } + matmul->base_.thread_nr_ = count; +} + +int MatmulBaseParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul) { + if (matmul->compute_.deep_ < kNumDeepThreshold) { + if (matmul->model_thread_nr_ != -1) { + matmul->base_.thread_nr_ = matmul->model_thread_nr_; + } + } + + if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && + (matmul->b_batch_ == matmul->a_batch_ || !matmul->support_mul_batch_cut_by_row_)) || + matmul->compute_.col_ == 1) { + matmul->compute_.batch_stride_ = UP_DIV(matmul->batch_, matmul->base_.thread_nr_); + matmul->parallel_run_ = matmul->parallel_run_by_batch_; + if (matmul->compute_.col_ != 1 || matmul->a_const_) { + return; + } + + matmul->parallel_run_ = matmul->parallel_run_not_pack_by_batch_; + if (matmul->compute_.deep_ == 1) { + matmul->gemm_not_pack_fun_ = GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize; + if (matmul->check_thread_cutting_by_row_(matmul)) { + matmul->parallel_run_ = matmul->parallel_run_by_row_; + matmul->get_thread_cutting_info_by_row_(matmul); + } + } + return; + } else if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && matmul->b_batch_ == 1) || + matmul->check_thread_cutting_by_row_(matmul)) { + matmul->parallel_run_ = matmul->parallel_run_by_row_; + matmul->get_thread_cutting_info_by_row_(matmul); + } else { + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + matmul->base_.thread_nr_ = MSMIN(matmul->base_.thread_nr_, total_col_unit); + int block_col_unit = UP_DIV(total_col_unit, matmul->base_.thread_nr_); + + int count = 0; + int split_point = 0; + while (split_point < total_col_unit) { + matmul->split_points_[count++] = (split_point * matmul->compute_.col_min_unit_); + split_point += block_col_unit; + } + matmul->base_.thread_nr_ = count; + matmul->parallel_run_ = matmul->parallel_run_by_oc_; + } + return; +} + +int 
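+ /* Packs the bias to col_align_ floats: a scalar bias is broadcast across the whole row, a vector bias is copied with its zero-padded tail. */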
MatmulBasePackBiasMatrix(MatmulStruct *matmul) { + if (matmul->base_.in_size_ != FOURTH_INPUT) { + return NNACL_OK; + } + if (matmul->matrix_c_.has_packed_) { + NNACL_CHECK_FALSE(matmul->matrix_c_.pack_size_ < matmul->compute_.col_align_, NNACL_ERR); + return NNACL_OK; + } + TensorC *bias_tensor = matmul->base_.in_[THIRD_INPUT]; + float *bias_src = matmul->matrix_c_.origin_ptr_ != NULL ? matmul->matrix_c_.origin_ptr_ : (float *)bias_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(bias_src); + + int bias_num = NNACLGetElementNum(bias_tensor); + NNACL_CHECK_TRUE_RET(bias_num > 0 && matmul->compute_.col_align_ >= bias_num, NNACL_ERR); + + matmul->matrix_c_.pack_size_ = matmul->compute_.col_align_; + if (matmul->matrix_c_.pack_ptr_ == NULL) { + matmul->matrix_c_.pack_ptr_ = (float *)(malloc(matmul->matrix_c_.pack_size_ * sizeof(float))); + } + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_c_.pack_ptr_); + + if (bias_num == 1) { + for (int i = 0; i < matmul->matrix_c_.pack_size_; ++i) { + matmul->matrix_c_.pack_ptr_[i] = bias_src[0]; + } + } else { + (void)memcpy(matmul->matrix_c_.pack_ptr_, bias_src, bias_num * sizeof(float)); + (void)memset(matmul->matrix_c_.pack_ptr_ + bias_num, 0, (matmul->matrix_c_.pack_size_ - bias_num) * sizeof(float)); + } + if (matmul->matrix_c_.origin_need_free_) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_c_.origin_ptr_); + matmul->matrix_c_.origin_ptr_ = NULL; + matmul->matrix_c_.origin_need_free_ = false; + } + return NNACL_OK; +} + +int MatmulBaseInitTmpOutBuffer(MatmulStruct *matmul) { + if (matmul->out_need_aligned_) { + if (matmul->output_data_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_); + } + // avx need to malloc dst aligned to C8NUM + // avx512 need to malloc dst aligned to C16NUM + int out_channel = matmul->compute_.col_; + NNACL_CHECK_ZERO_RETURN_ERR(matmul->compute_.col_tile_); + int oc_block_num = UP_DIV(out_channel, matmul->compute_.col_tile_); + int ele_num = matmul->batch_ * matmul->compute_.row_ * oc_block_num * matmul->compute_.col_tile_; + int data_size = ele_num * (int)sizeof(float); + matmul->output_data_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, data_size)); + NNACL_CHECK_NULL_RETURN_ERR(matmul->output_data_); + } + return NNACL_OK; +} + +void MatmulBaseInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
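+ /* Portable default tiling: A is packed 12 rows wide and B 8 columns wide, matching the tiled kernels the base runners call; arch-specific creators install their own tile sizes (AVX512, for instance, uses a 1 x 16 tile with 64-wide column units). */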
RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; + return; +} + +bool MatmulBaseCheckThreadCuttingByRow() { return false; } + +void MatmulBaseFreePackedMatrixA(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->matrix_a_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_a_.pack_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, matmul->matrix_a_.pack_ptr_); + } + matmul->matrix_a_.pack_ptr_ = NULL; +} + +void MatmulBaseFreePackedMatrixB(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->matrix_b_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_b_.pack_ptr_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_b_.pack_ptr_); + } + matmul->matrix_b_.pack_ptr_ = NULL; +} + +int MatmulBaseResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + int ret = matmul->init_parameter_(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + if (self->train_session_) { + self->work_size_ = (matmul->matrix_a_.pack_size_ + matmul->matrix_b_.pack_size_) * (int)sizeof(float); + } + + matmul->get_thread_cutting_policy_(matmul); + if (!matmul->matrix_c_.has_packed_) { + ret = MatmulBasePackBiasMatrix(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + if (!matmul->bias_need_repack_) { + matmul->matrix_c_.has_packed_ = true; + } + } + ret = MatmulBaseInitTmpOutBuffer(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + return NNACL_OK; +} + +int MatmulBaseRelease(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + MatmulBaseFreeBatchOffset(matmul); + + if (matmul->out_need_aligned_ && matmul->output_data_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_); + matmul->output_data_ = NULL; + } + if (matmul->matrix_c_.pack_ptr_ != NULL) { + free(matmul->matrix_c_.pack_ptr_); + matmul->matrix_c_.pack_ptr_ = NULL; + } + if (matmul->a_const_) { + if (matmul->is_sharing_pack_) { + matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_a_.pack_ptr_); + } else { + free(matmul->matrix_a_.pack_ptr_); + } + } + if (matmul->b_const_) { + if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) { + return NNACL_OK; + } + if (matmul->is_sharing_pack_) { + matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_b_.pack_ptr_); + } else { + free(matmul->matrix_b_.pack_ptr_); + } + } + return NNACL_OK; +} + +int MatmulBasePrepare(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + NNACL_CHECK_FALSE(matmul->base_.in_size_ < C2NUM, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.out_size_ < 1, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR); + + if (matmul->base_.in_size_ == THREE_TENSOR) { + NNACL_CHECK_TRUE_RET(matmul->base_.in_[THIRD_INPUT]->data_type_ == kNumberTypeFloat32, NNACL_MATMUL_BIAS_INVALID); + } + + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE( + param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6, + NNACL_MATMUL_ACT_TYPE_INVALID); + + int ret = matmul->init_parameter_(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, 
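+ /* after parameter init, Prepare pre-packs whichever operands are const so Compute only packs the variable side each run */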
ret); + + if (matmul->a_const_) { + ret = MatmulBasePackMatrixA(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + matmul->matrix_a_.has_packed_ = true; + } + if (matmul->b_const_) { + ret = MatmulBasePackMatrixB(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + matmul->matrix_b_.has_packed_ = true; + } + + if (matmul->base_.in_size_ == THREE_TENSOR) { + /* deal with const bias */ + bool bias_const = NNACLIsConst(self->in_[THIRD_INPUT]); + if (!matmul->infer_shape_ && bias_const && !matmul->base_.train_session_ && matmul->matrix_c_.origin_ptr_ == NULL) { + ret = MatmulBaseBackupConstMatrix(matmul, &matmul->matrix_c_, THIRD_INPUT); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + } + return NNACL_OK; +} + +int MatmulBaseCompute(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + float *out_data = (float *)(matmul->base_.out_[FIRST_INPUT]->data_); + NNACL_CHECK_FALSE(out_data == NULL, NNACL_ERR); + if (!matmul->out_need_aligned_) { + matmul->output_data_ = out_data; + } + + if (!matmul->a_const_) { + int ret = MatmulBasePackMatrixA(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + if (!matmul->b_const_) { + int ret = MatmulBasePackMatrixB(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_); + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MatmulFp32Run, self, self->thread_nr_); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (matmul->out_need_aligned_) { + PackNHWCXToNHWCFp32(matmul->output_data_, out_data, matmul->batch_, matmul->compute_.row_, matmul->compute_.col_, + matmul->compute_.col_tile_); + } else { + matmul->output_data_ = NULL; + } + if (!matmul->a_const_) { + MatmulBaseFreePackedMatrixA(self); + } + + if (!matmul->b_const_) { + MatmulBaseFreePackedMatrixB(self); + } + return NNACL_OK; +} + +void InitMatrixInfo(MatrixInfo *info) { + info->need_pack_ = false; + info->has_packed_ = false; + info->origin_need_free_ = false; + info->pack_size_ = -1; + info->origin_ptr_ = NULL; + info->pack_ptr_ = NULL; +} + +KernelBase *CreateMatmulBase() { + MatmulStruct *matmul = (MatmulStruct *)malloc(sizeof(MatmulStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + memset(matmul, 0, sizeof(MatmulStruct)); + matmul->base_.Prepare = MatmulBasePrepare; + matmul->base_.Resize = MatmulBaseResize; + matmul->base_.Release = MatmulBaseRelease; + matmul->base_.Compute = MatmulBaseCompute; + InitMatrixInfo(&(matmul->matrix_a_)); + InitMatrixInfo(&(matmul->matrix_b_)); + InitMatrixInfo(&(matmul->matrix_c_)); + matmul->is_sharing_pack_ = false; + matmul->pack_opt_ = false; + matmul->a_const_ = false; + matmul->b_const_ = false; + matmul->bias_need_repack_ = false; + matmul->out_need_aligned_ = false; + matmul->a_offset_ = NULL; + matmul->b_offset_ = NULL; + matmul->model_thread_nr_ = -1; + matmul->support_mul_batch_cut_by_row_ = false; + matmul->matmul_type_ = kMatmulFp32BaseCpu; + matmul->get_thread_cutting_policy_ = MatmulBaseGetThreadCuttingPolicy; + matmul->check_thread_cutting_by_row_ = MatmulBaseCheckThreadCuttingByRow; + matmul->get_thread_cutting_info_by_row_ = MatmulBaseGetThreadCuttingInfoByRow; + matmul->init_parameter_ = MatmulBaseInitParameter; + matmul->init_global_varibale_ = MatmulBaseInitGlobalVariable; + matmul->pack_matrix_a_impl_opt_ = MatmulBasePackMatrixAImplOpt; + matmul->pack_matrix_a_impl_ = MatmulBasePackMatrixAImpl; + matmul->pack_matrix_b_impl_ = MatmulBasePackMatrixBImpl; + 
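+ /* Default hooks; CreateMatmulAVX and CreateMatmulAVX512 swap in their own runners and parameter setup while reusing this Prepare/Resize/Release/Compute lifecycle. */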
matmul->parallel_run_by_batch_ = MatmulBaseParallelRunByBatch; + matmul->parallel_run_not_pack_by_batch_ = MatmulBaseParallelRunIsNotPackByBatch; + matmul->parallel_run_by_oc_ = MatmulBaseParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulBaseParallelRunByRow; + return (KernelBase *)matmul; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h new file mode 100644 index 0000000000000000000000000000000000000000..d840697fb624b92af3ce566f8d1807d119dc741c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_BASE_H_ +#define NNACL_KERNEL_MATMUL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/kernel/matmul_struct.h" + +void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul); +void MatmulBaseFreeBatchOffset(MatmulStruct *matmul); +int MatmulBaseMallocBatchOffset(MatmulStruct *matmul); +int MatmulBaseInitParameter(MatmulStruct *matmul); +int MatmulBasePrepare(KernelBase *self); +int MatmulBaseResize(KernelBase *self); +int MatmulBaseRelease(KernelBase *self); + +KernelBase *CreateMatmulBase(); + +#endif // NNACL_KERNEL_MATMUL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c new file mode 100644 index 0000000000000000000000000000000000000000..dac46193dbd351c2cd5cac9477f4766b6d2b2903 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c @@ -0,0 +1,82 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/matmul_create.h" +#include "nnacl_c/kernel/matmul_base.h" +#if defined(ENABLE_AVX512) +#include "nnacl_c/kernel/matmul_avx512.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#endif + +#if defined(ENABLE_AVX) +#include "nnacl_c/kernel/matmul_avx.h" +#endif + +#if defined(ENABLE_SSE) +#include "nnacl_c/kernel/matmul_sse.h" +#endif + +#if defined(ENABLE_ARM32) +#include "nnacl_c/kernel/matmul_arm32.h" +#endif + +#if defined(ENABLE_ARM64) +#include "nnacl_c/kernel/matmul_arm64.h" +#endif + +KernelBase *CreateMatmulKernel() { + KernelBase *matmul = NULL; + +#if defined(ENABLE_AVX512) + AVX512_HARDWARE_SELF_AWARENESS_BEGIN + matmul = CreateMatmulAVX512(); + if (matmul != NULL) { + return matmul; + } + AVX512_HARDWARE_SELF_AWARENESS_END +#endif + +#if defined(ENABLE_AVX) + matmul = CreateMatmulAVX(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_SSE) + matmul = CreateMatmulSSE(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_ARM64) + matmul = CreateMatmulARM64(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_ARM32) + matmul = CreateMatmulARM32(); + if (matmul != NULL) { + return matmul; + } +#endif + + matmul = CreateMatmulBase(); + return matmul; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h new file mode 100644 index 0000000000000000000000000000000000000000..a5cf3b44064797a89e0d3dabb69e733a3d498994 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h @@ -0,0 +1,24 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_CREATE_H_ +#define NNACL_KERNEL_MATMUL_CREATE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulKernel(); + +#endif // NNACL_KERNEL_MATMUL_CREATE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c new file mode 100644 index 0000000000000000000000000000000000000000..9ee236b8d40d8c645e5243fd0e75dc0452edfd71 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_SSE +#include "nnacl_c/kernel/matmul_sse.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +void MatmulSSEInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row4MajorParallel : RowMajor2Col4MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; + matmul->compute_.row_tile_ = C4NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; +} + +int MatmulSSEParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulSSEParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulSSE() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulSSEInitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulSSEParallelRunByOC; + matmul->parallel_run_by_batch_ = MatmulSSEParallelRunByBatch; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h new file mode 100644 index 0000000000000000000000000000000000000000..78c7c0e379bbea165b4ddcba98f768d8dcbf832e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_SSE_H_ +#define NNACL_KERNEL_MATMUL_SSE_H_ +#ifdef ENABLE_SSE +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulSSE(); + +#endif +#endif // NNACL_KERNEL_MATMUL_SSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h new file mode 100644 index 0000000000000000000000000000000000000000..501249cf50a8e1b1507675ba99e6b39400c70d26 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h @@ -0,0 +1,133 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_STRUCT_H_ +#define NNACL_KERNEL_MATMUL_STRUCT_H_ + +#include "nnacl_c/kernel.h" +#include "nnacl_c/matmul_parameter.h" + +#define SPLIT_COUNT MAX_THREAD_NUM + +typedef struct MatrixInfo { + bool need_pack_; + bool has_packed_; // only valid for constant, only do once throughout the process. 
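+  /* Lifecycle sketch (as used in matmul_base.c): pack_ptr_ receives the
+   * re-laid-out copy written by the matrix pack callbacks and is released in
+   * MatmulBaseFreePackedMatrixA/B or MatmulBaseRelease; origin_ptr_ keeps a
+   * backup of constant input data (see MatmulBaseBackupConstMatrix). */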
+ bool origin_need_free_; // true when failing to infer shape, false in conv1x1 free in convolution delegate + int pack_size_; + float *origin_ptr_; // only valid for constant, which is synchronized with the 'has_origin'. + float *pack_ptr_; +} MatrixInfo; + +typedef struct MatmulSlice { + int row_s_; + int row_e_; + int col_s_; + int col_e_; +} MatmulSlice; + +typedef struct MatmulComputeParam { + int row_; + int col_; + int deep_; + int row_align_; + int col_align_; + int deep_align_; + int row_num_; + int col_tile_; + int row_tile_; + int col_step_; + int row_min_unit_; + int col_min_unit_; + int batch_stride_; + int pack_b_stride_; + int block_col_unit_; +} MatmulComputeParam; + +typedef struct MatmulStruct { + KernelBase base_; + MatmulComputeParam compute_; + MatmulType matmul_type_; + + /* model pool optimize */ + int model_thread_nr_; + + /* batch-matmul broadcast */ + int batch_; + int a_batch_; + int b_batch_; + int *a_offset_; /* batch_ size */ + int *b_offset_; /* batch_ size */ + + int split_points_[SPLIT_COUNT]; + + float *output_data_; + float *pack_b_src_; + float *pack_b_dst_; + + bool a_const_; + bool b_const_; + bool bias_need_repack_; + bool infer_shape_; + bool pack_opt_; + bool is_sharing_pack_; + bool out_need_aligned_; + bool weight_is_packed_; + bool support_mul_batch_cut_by_row_; + + MatrixInfo matrix_a_; + MatrixInfo matrix_b_; + MatrixInfo matrix_c_; + + void (*matrix_a_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row); + void (*matrix_b_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row); + + int (*pack_matrix_a_impl_opt_)(struct MatmulStruct *matmul); + int (*pack_matrix_a_impl_)(struct MatmulStruct *matmul); + int (*pack_matrix_b_impl_)(struct MatmulStruct *matmul); + + int (*init_parameter_)(struct MatmulStruct *matmul); + void (*init_global_varibale_)(struct MatmulStruct *matmul); + + bool (*check_thread_cutting_by_row_)(struct MatmulStruct *matmul); + void (*get_thread_cutting_policy_)(struct MatmulStruct *matmul); + void (*get_thread_cutting_info_by_row_)(struct MatmulStruct *matmul); + + void *shaing_manager_; + void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed); + void (*free_sharing_weight_)(void *manager, void *tensor_data); + + void (*gemm_not_pack_fun_)(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type); + + int (*parallel_run_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_row_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_oc_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_batch_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_not_pack_by_batch_)(struct MatmulStruct *matmul, int task_id); + + /* optimize for avx512 */ + int col_split_points_size_; + int row_split_points_size_; + int col_split_points_[SPLIT_COUNT]; + int row_split_points_[SPLIT_COUNT]; + int matmul_slice_count_[SPLIT_COUNT]; + MatmulSlice matmul_slice_set_[SPLIT_COUNT][SPLIT_COUNT]; + int (*parallel_run_by_gemm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_gepm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_gepdot_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_batch_col_row_gemm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_row1_deep1_gepdot_)(struct MatmulStruct *matmul, int task_id); +} MatmulStruct; + +#endif // 
NNACL_KERNEL_MATMUL_STRUCT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c new file mode 100644 index 0000000000000000000000000000000000000000..9b49f325ac68329ad044379e409f67609e60195a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/nllloss.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/nllloss_fp32.h" +#include "nnacl_c/nllloss_parameter.h" + +int NlllossCompute(KernelBase *self) { + NLLLossStruct *nllloss = (NLLLossStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nllloss); + float *logits = self->in_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(logits); + int *labels = self->in_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(labels); + float *weight = self->in_[Index2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(weight); + + float *loss = self->out_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(loss); + float *total_weight = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(total_weight); + + ReductionType reduction_type = ((NLLLossParameter *)self->param_)->reduction_type_; + return NLLLoss(logits, labels, weight, loss, total_weight, nllloss, reduction_type); +} + +int NlllossPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + NLLLossStruct *nllloss = (NLLLossStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nllloss); + TensorC *logits_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(logits_tensor); + nllloss->batch_ = logits_tensor->shape_[Index0]; + nllloss->class_num_ = logits_tensor->shape_[Index1]; + return NNACL_OK; +} + +KernelBase *CreateNLLLoss(OpParameter *param, int data_type) { + NLLLossStruct *nllloss = (NLLLossStruct *)malloc(sizeof(NLLLossStruct)); + NNACL_CHECK_NULL_RETURN_NULL(nllloss); + nllloss->base_.Release = DefaultRelease; + nllloss->base_.Prepare = NlllossPrepare; + nllloss->base_.Resize = DefaultResize; + nllloss->base_.Compute = NlllossCompute; + return (KernelBase *)nllloss; +} + +REG_KERNEL_CREATOR(PrimType_NLLLoss, kNumberTypeFloat32, CreateNLLLoss) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h new file mode 100644 index 0000000000000000000000000000000000000000..0527b0492008f3d8684dbfbe9a49567022f545ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_NLLLOSS_H_ +#define NNACL_KERNEL_NLLLOSS_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int batch_; + int class_num_; +} NLLLossStruct; + +KernelBase *CreateNLLLoss(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NLLLOSS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c new file mode 100644 index 0000000000000000000000000000000000000000..9787dd6ec963af30e189accbaf9bc9d7d70408d1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c @@ -0,0 +1,126 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/non_max_suppression.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/non_max_suppression_parameter.h" +#include "nnacl_c/fp32/non_max_suppression_fp32.h" + +void NonMaxSuppressioExpandDims(int *dst_shape, int *origin_shape, size_t size) { + int i = 0; + for (; i < size; i++) { + dst_shape[i] = 1; + } + for (; i < Num3; i++) { + dst_shape[i] = origin_shape[i - size]; + } +} + +void NonMaxSuppressionGetParams(NonMaxSuppressionStruct *nm_suppression) { + // optional input order: max_output_per_class, iou_threshold, score_threshold + nm_suppression->max_output_per_class_ = 0; + if (nm_suppression->base_.in_size_ >= Num3) { + TensorC *max_output_tensor = nm_suppression->base_.in_[Index3]; + if (max_output_tensor != NULL && max_output_tensor->data_ != NULL) { + nm_suppression->max_output_per_class_ = *(int *)(max_output_tensor->data_); + } + } + + nm_suppression->iou_threshold_ = 0.0f; + if (nm_suppression->base_.in_size_ >= Num4) { + TensorC *iou_threshold_tensor = nm_suppression->base_.in_[Index4]; + if (iou_threshold_tensor != NULL && iou_threshold_tensor->data_ != NULL) { + nm_suppression->iou_threshold_ = *(float *)(iou_threshold_tensor->data_); + } + } + + nm_suppression->score_threshold_ = 0.0f; + if (nm_suppression->base_.in_size_ >= Num5) { + TensorC *score_threshold_tensor = nm_suppression->base_.in_[Index5]; + if (score_threshold_tensor != NULL && score_threshold_tensor->data_ != NULL) { + nm_suppression->score_threshold_ = *(float *)(score_threshold_tensor->data_); + } + } +} + +int NonMaxSuppressionCompute(KernelBase *self) { + NonMaxSuppressionStruct *nm_suppression = (NonMaxSuppressionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nm_suppression); + + NonMaxSuppressionGetParams(nm_suppression); + + TensorC *box_tensor = self->in_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(box_tensor); + int box_dims[Num3] = {0}; // batch, box_num, 4 + bool simple_out = false; + if (box_tensor->shape_size_ != Num3) { + NonMaxSuppressioExpandDims(box_dims, box_tensor->shape_, Num3 - box_tensor->shape_size_); + simple_out = true; + } + if (box_dims[Index2] != Num4) { + return NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_INVALID; + } + + TensorC *score_tensor = self->in_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(score_tensor); + int score_dims[Num3] = {0}; // batch, class, box_num + if (score_tensor->shape_size_ != Num3) { + NonMaxSuppressioExpandDims(score_dims, score_tensor->shape_, Num3 - score_tensor->shape_size_); + } + if (score_dims[Index0] != box_dims[Index0]) { + return NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_SCORE_UNMATCH; + } + if (score_dims[Index2] != box_dims[Index1]) { + return NNACL_NON_MAX_SUPPRESSION_DIMENSION_SPATIAL_UNMATCH; + } + if (nm_suppression->base_.out_[OUTPUT_INDEX]->data_ != NULL) { + /* output shape and data set in compute */ + return NNACL_NON_MAX_SUPPRESSION_UNSUPPORT_DEFINE_DATA; + } + return NonMaxSuppressionSelecte(nm_suppression, simple_out, score_dims); +} + +int NonMaxSuppressionPrepare(KernelBase *self) { + NonMaxSuppressionStruct *nm_suppression = (NonMaxSuppressionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nm_suppression); + + // boxes, scores, max_output_boxes, iou_threshold, score_threshold + if (self->in_size_ < Num2 || self->in_size_ > Num5 || self->out_size_ != Num1) { + return NNACL_NON_MAX_SUPPRESSION_TENSOR_SIZE_INVALID; + } + + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NMSParameter *nmparam = (NMSParameter *)self->param_; + 
NNACL_CHECK_NULL_RETURN_ERR(nmparam); + if (nmparam->center_point_box_ != 0 && nmparam->center_point_box_ != 1) { + return NNACL_NON_MAX_SUPPRESSION_PARAM_INVALID; + } + nm_suppression->center_point_box_ = nmparam->center_point_box_; + return NNACL_OK; +} + +KernelBase *CreateNonMaxSuppression(OpParameter *param, int data_type) { + NonMaxSuppressionStruct *non_max_suppression = (NonMaxSuppressionStruct *)malloc(sizeof(NonMaxSuppressionStruct)); + NNACL_CHECK_NULL_RETURN_NULL(non_max_suppression); + non_max_suppression->base_.Release = DefaultRelease; + non_max_suppression->base_.Resize = DefaultResize; + non_max_suppression->base_.Prepare = NonMaxSuppressionPrepare; + non_max_suppression->base_.Compute = NonMaxSuppressionCompute; + return (KernelBase *)non_max_suppression; +} + +REG_KERNEL_CREATOR(PrimType_NonMaxSuppression, kNumberTypeFloat32, CreateNonMaxSuppression) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h new file mode 100644 index 0000000000000000000000000000000000000000..39d5485be9a3de12ab57a255e90cc36473a60915 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ +#define NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int center_point_box_; + int max_output_per_class_; + float iou_threshold_; + float score_threshold_; +} NonMaxSuppressionStruct; + +KernelBase *CreateNonMaxSuppression(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c new file mode 100644 index 0000000000000000000000000000000000000000..6d9c46ec37da62a198f123c8d11bf421d93bfd3a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c @@ -0,0 +1,69 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/non_zero.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int NonZeroCompute(KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_2D, NNACL_NON_ZERO_SHAPE_INVALID); + + bool *input_data = (bool *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + int *output_data = (int *)output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + int non_zero_nums = output->shape_[Index1]; + int non_zero_count = 0; + + int *coordiate_values = (int *)self->env_->Alloc(self->env_->allocator_, input->shape_size_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(coordiate_values); + + for (int i = 0; i < NNACLGetElementNum(input); i += 1) { + if (input_data[i]) { + for (size_t j = 0; j < input->shape_size_; j++) { + output_data[non_zero_count + (int)j * non_zero_nums] = coordiate_values[j]; + } + non_zero_count++; + } + for (size_t idx = input->shape_size_; idx >= 1; --idx) { + if (coordiate_values[idx - 1] != input->shape_[idx - 1] - 1) { + coordiate_values[idx - 1] = coordiate_values[idx - 1] + 1; + break; + } + coordiate_values[idx - 1] = 0; + } + } + + return NNACL_OK; +} + +KernelBase *CreateNonZero(OpParameter *param, int data_type) { + NonZeroStruct *non_zero = (NonZeroStruct *)malloc(sizeof(NonZeroStruct)); + NNACL_CHECK_NULL_RETURN_NULL(non_zero); + non_zero->base_.Release = DefaultRelease; + non_zero->base_.Prepare = DefaultPrepare2In1Out; + non_zero->base_.Resize = DefaultResize; + non_zero->base_.Compute = NonZeroCompute; + return (KernelBase *)non_zero; +} + +REG_KERNEL_CREATOR(PrimType_NonZero, kNumberTypeBool, CreateNonZero) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h new file mode 100644 index 0000000000000000000000000000000000000000..383e67e00bcbe0550233c701af1d167a26ad1c92 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_NON_ZERO_H_ +#define NNACL_KERNEL_NON_ZERO_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; +} NonZeroStruct; + +KernelBase *CreateNonZero(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NON_ZERO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c new file mode 100644 index 0000000000000000000000000000000000000000..f303e03d1637bb9103493545863118fa903e9408 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c @@ -0,0 +1,193 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/one_hot.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/one_hot_parameter.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/fp32/one_hot_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/one_hot_fp16.h" +#endif + +int OneHotRun(void *cdata, int task_id, float l, float r) { + OneHotStruct *one_hot = (OneHotStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + + int *indices_data = (int *)one_hot->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_data); + + TensorC *output_tensor = one_hot->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + void *output_data = one_hot->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + if (output_tensor->data_type_ == kNumberTypeFloat32) { + return OneHotToFp32(indices_data, one_hot->on_value_, one_hot->off_value_, (float *)output_data, one_hot, task_id, + one_hot->base_.thread_nr_); +#ifdef ENABLE_FP16 + } else if (output_tensor->data_type_ == kNumberTypeFloat16) { + return OneHotToFp16(indices_data, (float16_t)one_hot->on_value_, (float16_t)one_hot->off_value_, + (float16_t *)output_data, one_hot, task_id, one_hot->base_.thread_nr_); +#endif + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int OneHotInitOnOffValueForFourInputs(OneHotStruct *one_hot) { + TensorC *on_value_tensor = one_hot->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(on_value_tensor); + void *on_value_data = on_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(on_value_data); + if (on_value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->on_value_ = *((float *)on_value_data); +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (on_value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->on_value_ = *((float16_t *)on_value_data); +#endif + } else { + return NNACL_ONE_HOR_ON_VALUE_TENSOR_DATA_TYPE_INVALID; + } + + TensorC *off_value_tensor = one_hot->base_.in_[FOURTH_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(off_value_tensor); + void *off_value_data = off_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(off_value_data); + if (on_value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->off_value_ = *((float *)off_value_data); +#if defined(ENABLE_ARM) 
&& defined(ENABLE_FP16) + } else if (on_value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->off_value_ = *((float16_t *)off_value_data); +#endif + } else { + return NNACL_ONE_HOR_OFF_VALUE_TENSOR_DATA_TYPE_INVALID; + } + + return NNACL_OK; +} + +int OneHotInitOnOffValueForThreeInputs(OneHotStruct *one_hot) { + TensorC *value_tensor = one_hot->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(value_tensor); + void *value_data = value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(value_data); + + if (value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->off_value_ = ((float *)value_data)[Index0]; + one_hot->on_value_ = ((float *)value_data)[Index1]; +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->off_value_ = ((float16_t *)value_data)[Index0]; + one_hot->on_value_ = ((float16_t *)value_data)[Index1]; +#endif + } else { + return NNACL_ONE_HOR_ON_OFF_VALUE_TENSOR_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int OneHotInitParamsAndOnOffValue(OneHotStruct *one_hot) { + TensorC *depth_tensor = one_hot->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(depth_tensor); + + if (depth_tensor->data_type_ == kNumberTypeInt32) { + const int *depth = (int *)depth_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(depth); + one_hot->depth_ = *depth; + } else { + return NNACL_ONE_HOR_DEPTH_TENSOR_DATA_TYPE_INVALID; + } + + if (one_hot->base_.in_size_ == FOUR_TENSOR) { + // 4 inputs: indices, depth, on_value, off_value + one_hot->support_neg_index_ = false; + int ret = OneHotInitOnOffValueForFourInputs(one_hot); + if (ret != NNACL_OK) { + return ret; + } + } else { + // 3 inputs: indices, depth, off_on_value + one_hot->support_neg_index_ = true; + int ret = OneHotInitOnOffValueForThreeInputs(one_hot); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int OneHotCompute(KernelBase *self) { + OneHotStruct *one_hot = (OneHotStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + int ret = OneHotInitParamsAndOnOffValue(one_hot); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, OneHotRun, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int OneHotPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ != FOUR_TENSOR && self->in_size_ != THREE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + TypeIdC data_type = self->in_[FIRST_INPUT]->data_type_; + NNACL_CHECK_FALSE(data_type != kNumberTypeInt32 && data_type != kNumberTypeInt64, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +int OneHotResize(KernelBase *self) { + OneHotStruct *one_hot = (OneHotStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + + TensorC *indices = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices); + + int origin_axis = ((OneHotParameter *)self->param_)->axis_; + one_hot->axis_ = origin_axis < 0 ? 
origin_axis + (int)indices->shape_size_ + 1 : origin_axis;
+  NNACL_CHECK_FALSE(one_hot->axis_ < 0 || one_hot->axis_ > (int)indices->shape_size_, NNACL_ONE_HOT_AXIS_INVALID);
+
+  one_hot->outer_size_ = 1;
+  for (int i = 0; i < one_hot->axis_; i++) {
+    one_hot->outer_size_ *= indices->shape_[i];
+  }
+  if (one_hot->outer_size_ == 0) {
+    return NNACL_ONE_HOT_OUTER_SIZE_INVALID;
+  }
+  one_hot->inner_size_ = NNACLGetElementNum(indices) / one_hot->outer_size_;
+  NNACL_CHECK_FALSE(one_hot->inner_size_ <= 0, NNACL_ONE_HOT_INNER_SIZE_INVALID);
+
+  self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_OneHot), one_hot->inner_size_, one_hot->outer_size_,
+                                        NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_);
+  return NNACL_OK;
+}
+
+KernelBase *CreateOneHot(OpParameter *param, int data_type) {
+  OneHotStruct *one_hot = (OneHotStruct *)malloc(sizeof(OneHotStruct));
+  NNACL_CHECK_NULL_RETURN_NULL(one_hot);
+  one_hot->base_.Release = DefaultRelease;
+  one_hot->base_.Prepare = OneHotPrepare;
+  one_hot->base_.Resize = OneHotResize;
+  one_hot->base_.Compute = OneHotCompute;
+  return (KernelBase *)one_hot;
+}
+
+REG_KERNEL_CREATOR(PrimType_OneHot, kNumberTypeInt32, CreateOneHot)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h
new file mode 100644
index 0000000000000000000000000000000000000000..d945485b1f092b119385240abe94b586d74ec396
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_ONE_HOT_H_
+#define NNACL_KERNEL_ONE_HOT_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct {
+  KernelBase base_;
+  int axis_;
+  int depth_;
+  int outer_size_;
+  int inner_size_;
+  bool support_neg_index_;
+  float on_value_;
+  float off_value_;
+} OneHotStruct;
+
+KernelBase *CreateOneHot(OpParameter *param, int data_type);
+
+#endif // NNACL_KERNEL_ONE_HOT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c
new file mode 100644
index 0000000000000000000000000000000000000000..fff44c1579d7c6524fa4e95f50258e878547dc97
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/ones_like.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +#define ApproximateOnesLike(output, data_size) \ + for (size_t i = 0; i < data_size; ++i) { \ + output[i] = 1; \ + } + +int OnesLikeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + void *output_ptr = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + size_t num = (size_t)NNACLGetElementNum(output_tensor); + + if (output_tensor->data_type_ == kNumberTypeFloat32) { + float *output = (float *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } +#ifdef ENABLE_FP16 + if (output_tensor->data_type_ == kNumberTypeFloat16) { + float16_t *output = (float16_t *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } +#endif + if (output_tensor->data_type_ == kNumberTypeInt32) { + int *output = (int *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +KernelBase *CreateOnesLike(OpParameter *param, int data_type) { + OnesLikeStruct *ones_like = (OnesLikeStruct *)malloc(sizeof(OnesLikeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(ones_like); + ones_like->data_type_ = data_type; + ones_like->base_.Release = DefaultRelease; + ones_like->base_.Prepare = DefaultPrepare1In1Out; + ones_like->base_.Resize = DefaultResize; + ones_like->base_.Compute = OnesLikeCompute; + return (KernelBase *)ones_like; +} + +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeInt32, CreateOnesLike) +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeFloat32, CreateOnesLike) +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeFloat16, CreateOnesLike) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h new file mode 100644 index 0000000000000000000000000000000000000000..1027b72094b6f903b58453a9078dff15f490985a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_ONES_LIKE_H_ +#define NNACL_KERNEL_ONES_LIKE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct OnesLikeStruct { + KernelBase base_; + int data_type_; +} OnesLikeStruct; + +KernelBase *CreateOnesLike(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ONES_LIKE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c new file mode 100644 index 0000000000000000000000000000000000000000..0f742ddb4923c590226acf666da1b5265d5c0903 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c @@ -0,0 +1,406 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/pad.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/common_func.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pad_fp16.h" +#endif +#include "nnacl_c/fp32/pad_fp32.h" + +int PadInitMirrorPadBlock(PadStruct *pad) { + int left_pads[DEFAULT_PAD_NDIMS] = {0}; + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + left_pads[i] = pad->paddings_[Num2 * i]; + } + + int input_separate_dims[DEFAULT_PAD_NDIMS] = {0}; + int output_separate_dims[DEFAULT_PAD_NDIMS] = {0}; + int separate_offset[DEFAULT_PAD_NDIMS] = {0}; + int separate_size = 0; + + /* init separate dims */ + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + input_separate_dims[separate_size] = pad->in_[i]; + output_separate_dims[separate_size] = pad->out_[i]; + separate_offset[separate_size] = left_pads[i]; + separate_size++; + } + + /* init separate stride */ + int output_separate_stride[DEFAULT_PAD_NDIMS] = {0}; + (void)GetStride(output_separate_stride, output_separate_dims, separate_size); + int remain_stride_size = 0; + int remain_size = 1; + int right_pads[DEFAULT_PAD_NDIMS] = {0}; + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; i++) { + right_pads[i] = output_separate_dims[i] - input_separate_dims[i] - separate_offset[i]; + } + + /* init pad region */ + int pad_region[DEFAULT_PAD_NDIMS] = {0}; + int pad_region_size = 0; + for (int i = remain_stride_size; i < separate_size; ++i) { + int r = 1; + r = (separate_offset[i] > 0) ? (r + 1) : r; + r = (right_pads[i] > 0) ? 
(r + 1) : r; + pad_region[pad_region_size++] = r; + } + int pad_region_stride[DEFAULT_PAD_NDIMS] = {0}; + int region_size = GetStride(pad_region_stride, pad_region, pad_region_size); + + /* init mirror block info */ + int max_block_size = remain_size * region_size * sizeof(MirrorPadBlock); + pad->mirror_pad_block_ = (MirrorPadBlock *)pad->base_.env_->Alloc(pad->base_.env_->allocator_, max_block_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pad->mirror_pad_block_); + + // 0: center, 1: left, 2: right + int pad_cord[DEFAULT_PAD_NDIMS] = {0}; + + for (int pos = 0; pos < remain_size; ++pos) { + const int dst_basic_offset = 0; + for (int index = 1; index < region_size; ++index) { + int dst_offset = dst_basic_offset; + int value = index; + for (size_t i = 0; i < pad_region_size && pad_region_stride[i] != 0; ++i) { + NNACL_CHECK_ZERO_RETURN_ERR(pad_region_stride[i]); + pad_cord[i] = value / pad_region_stride[i]; + value = value % pad_region_stride[i]; + } + MirrorPadBlock block; + const int size_offset = DEFAULT_PAD_NDIMS - pad_region_size; + for (size_t i = 0; i < pad_region_size; ++i) { + int di = size_offset + i; + int si = remain_stride_size + i; + if (di >= DEFAULT_PAD_NDIMS) { + continue; + } + switch (pad_cord[i]) { + case Num0: + dst_offset += separate_offset[si] * output_separate_stride[si]; + block.size_[di] = input_separate_dims[si]; + block.out_stride_[di] = output_separate_stride[si]; + break; + case Num2: + dst_offset += (separate_offset[si] + input_separate_dims[si]) * output_separate_stride[si]; + block.size_[di] = right_pads[si]; + block.out_stride_[di] = output_separate_stride[si]; + break; + case Num1: + if (separate_offset[si] > 0) { + block.size_[di] = separate_offset[si]; + block.out_stride_[di] = output_separate_stride[si]; + } else { + dst_offset += (separate_offset[si] + input_separate_dims[si]) * output_separate_stride[si]; + block.size_[di] = right_pads[si]; + block.out_stride_[di] = output_separate_stride[si]; + } + break; + default: + break; + } + } + block.out_offset_ = dst_offset; + pad->mirror_pad_block_[pad->mirror_pad_block_size_++] = block; + } + } + return NNACL_OK; +} + +int PadExtendDims(int *dims, const int *origin_dims, int max_dim, int origin_dim, int init_value) { + NNACL_CHECK_NULL_RETURN_ERR(dims); + NNACL_CHECK_NULL_RETURN_ERR(origin_dims); + for (int i = 0; i < max_dim - origin_dim; ++i) { + dims[i] = init_value; + } + for (int i = max_dim - origin_dim; i < max_dim; ++i) { + dims[i] = origin_dims[i - (max_dim - origin_dim)]; + } + return NNACL_OK; +} + +int PadImpl(void *cdata, int task_id, float l, float r) { + PadStruct *pad = (PadStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(pad); + void *input = pad->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + void *output = pad->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output); + + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + PadFp16(input, output, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); +#endif + } else { + Pad((float *)input, (float *)output, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); + } + return NNACL_OK; +} + +int PadFastMirrorRunImpl(PadStruct *pad, int task_id) { + void *in = pad->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *out = pad->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out); + + /* copy center part */ + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + PadFp16((float16_t *)in, (float16_t *)out, pad->in_, 
pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); +#endif + } else { + Pad((float *)in, (float *)out, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); + } + + /* calculate region part */ + for (int i = task_id; i < pad->mirror_pad_block_size_; i += pad->base_.thread_nr_) { + MirrorPadBlock *block = &pad->mirror_pad_block_[i]; + for (int a = 0; a < block->size_[FIRST_INPUT]; a++) { + int out_a_index = block->out_offset_ + a * block->out_stride_[FIRST_INPUT]; + for (int b = 0; b < block->size_[SECOND_INPUT]; b++) { + int out_b_index = out_a_index + b * block->out_stride_[SECOND_INPUT]; + for (int c = 0; c < block->size_[THIRD_INPUT]; ++c) { + int out_c_index = out_b_index + c * block->out_stride_[THIRD_INPUT]; + for (int d = 0; d < block->size_[FOURTH_INPUT]; ++d) { + int out_d_index = out_c_index + d * block->out_stride_[FOURTH_INPUT]; + for (int e = 0; e < block->size_[FIFTH_INPUT]; ++e) { + int start_index = out_d_index + e * block->out_stride_[FIFTH_INPUT]; + int end_index = start_index + block->size_[SIXTH_INPUT]; + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + MirrorPadFp16(in, out, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, + pad->mirror_offset_, start_index, end_index); +#endif + } else { + MirrorPad(in, out, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, pad->mirror_offset_, + start_index, end_index); + } + } + } + } + } + } + } + return NNACL_OK; +} + +int MirrorPadImpl(void *cdata, int task_id, float l, float r) { + PadStruct *pad = (PadStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(pad); + + /* Fast Mirror pad */ + if (pad->mirror_pad_block_size_ != 0) { + return PadFastMirrorRunImpl(pad, task_id); + } + + TensorC *input = pad->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + TensorC *output = pad->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + /* Common Mirror pad */ + int unit = UP_DIV(NNACLGetElementNum(output), pad->base_.thread_nr_); + int begin = unit * task_id; + int end = NNACL_MIN(begin + unit, NNACLGetElementNum(output)); + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + MirrorPadFp16((float16_t *)input_data, (float16_t *)output_data, pad->in_, pad->in_strides_, pad->out_strides_, + pad->paddings_, pad->mirror_offset_, begin, end); +#endif + } else { + MirrorPad((float *)input_data, (float *)output_data, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, + pad->mirror_offset_, begin, end); + } + return NNACL_OK; +} + +int PadCheckPaddings(const int *paddings, int length, const int *input_shape, int mode) { + NNACL_CHECK_NULL_RETURN_ERR(paddings); + NNACL_CHECK_NULL_RETURN_ERR(input_shape); + int offset = mode == PaddingMode_Symmetric ? 
0 : 1;
+  for (int i = 0; i < length; ++i) {
+    int max_valid = input_shape[i] - offset;
+    if (paddings[i * Num2] > max_valid) {
+      return NNACL_PAD_MIRROR_PAD_SIZE_INVALID;
+    }
+    if (paddings[i * Num2 + 1] > max_valid) {
+      return NNACL_PAD_MIRROR_PAD_SIZE_INVALID;
+    }
+  }
+  return NNACL_OK;
+}
+
+int PadCopyPaddingFromInput(PadStruct *pad) {
+  TensorC *input_tensor = pad->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *padding_tensor = pad->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(padding_tensor);
+  int *padding_data = padding_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(padding_data);
+
+  (void)PadExtendDims(pad->in_, input_tensor->shape_, DEFAULT_PAD_NDIMS, input_tensor->shape_size_, 1);
+  (void)PadExtendDims(pad->paddings_, padding_data, MAX_PAD_SIZE, NNACLGetElementNum(padding_tensor), 0);
+  pad->paddings_size_ = MAX_PAD_SIZE;
+
+  return NNACL_OK;
+}
+
+void PadCalculateStrides(PadStruct *pad) {
+  pad->in_strides_[DEFAULT_PAD_NDIMS - 1] = 1;
+  for (int i = DEFAULT_PAD_NDIMS - Num2; i >= 0; --i) {
+    pad->in_strides_[i] = pad->in_[i + 1] * pad->in_strides_[i + 1];
+  }
+  for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
+    pad->out_[i] = pad->in_[i] + pad->paddings_[i * Num2] + pad->paddings_[i * Num2 + 1];
+  }
+  pad->out_strides_[DEFAULT_PAD_NDIMS - 1] = 1;
+  for (int i = DEFAULT_PAD_NDIMS - Num2; i >= 0; --i) {
+    pad->out_strides_[i] = pad->out_[i + 1] * pad->out_strides_[i + 1];
+  }
+}
+
+int PadHandleMirrorPad(PadStruct *pad) {
+  pad->mirror_offset_ = pad->pad_mode_ == PaddingMode_Reflect ? 1 : 0;
+  int ret = PadCheckPaddings(pad->paddings_, DEFAULT_PAD_NDIMS, pad->in_, pad->pad_mode_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  PadCalculateStrides(pad);
+  return PadInitMirrorPadBlock(pad);
+}
+
+int PadCompute(KernelBase *self) {
+  PadStruct *pad = (PadStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(pad);
+
+  if (self->in_size_ == THREE_TENSOR) {
+    TensorC *pad_value_tensor = self->in_[THIRD_INPUT];
+    NNACL_CHECK_NULL_RETURN_ERR(pad_value_tensor);
+    NNACL_CHECK_FALSE(NNACLGetElementNum(pad_value_tensor) != 1, NNACL_PAD_PADDING_VALID_INVALID);
+    void *pad_value = pad_value_tensor->data_;
+    NNACL_CHECK_NULL_RETURN_ERR(pad_value);
+    if (pad->data_type_ == kNumberTypeFloat16) {
+#ifdef ENABLE_FP16
+      pad->constant_value_ = ((float16_t *)pad_value)[Index0];
+#endif
+    } else {
+      pad->constant_value_ = ((float *)pad_value)[Index0];
+    }
+  }
+
+  int ret = PadCopyPaddingFromInput(pad);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  if (pad->pad_mode_ == PaddingMode_Constant) {
+    TensorC *output = self->out_[OUTPUT_INDEX];
+    NNACL_CHECK_NULL_RETURN_ERR(output);
+    size_t output_size = NNACLGetElementNum(output);
+    void *output_data = output->data_;
+    NNACL_CHECK_NULL_RETURN_ERR(output_data);
+    if (fabsf(pad->constant_value_ - 0.0f) < 1e-5) {
+      memset(output_data, 0, output_size * (int)DataTypeCSize(pad->data_type_));
+    } else {
+      for (size_t i = 0; i < output_size; ++i) {
+        if (pad->data_type_ == kNumberTypeFloat16) {
+#ifdef ENABLE_FP16
+          ((float16_t *)output_data)[i] = pad->constant_value_;
+#endif
+        } else {
+          ((float *)output_data)[i] = pad->constant_value_;
+        }
+      }
+    }
+    ret = self->env_->ParallelLaunch(self->env_->thread_pool_, PadImpl, self, self->thread_nr_);
+    return ret;
+  }
+
+  /* non-constant pad modes use the mirror pad algorithm */
+  ret = PadHandleMirrorPad(pad);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MirrorPadImpl, self, self->thread_nr_);
+
+  self->env_->Free(self->env_->allocator_, pad->mirror_pad_block_);
+  pad->mirror_pad_block_ = NULL;
+  pad->mirror_pad_block_size_ = 0;
+  return ret;
+}
+
+int
PadResize(KernelBase *self) { + PadStruct *pad = (PadStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pad); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *padding = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(padding); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int rank = input->shape_size_; + NNACL_CHECK_FALSE(input->shape_size_ > DEFAULT_PAD_NDIMS, NNACL_PAD_SHAPE_INVALID); + NNACL_CHECK_FALSE(NNACLGetElementNum(padding) != rank + rank, NNACL_PAD_SHAPE_INVALID); + + if (pad->pad_mode_ == PaddingMode_Constant) { + (void)PadExtendDims(pad->in_, input->shape_, DEFAULT_PAD_NDIMS, rank, 1); + (void)PadExtendDims(pad->out_, output->shape_, DEFAULT_PAD_NDIMS, rank, 1); + + if (pad->paddings_size_ < MAX_PAD_SIZE) { + int ori_paddings[MAX_PAD_SIZE]; + memcpy(ori_paddings, pad->paddings_, MAX_PAD_SIZE * sizeof(int)); + (void)PadExtendDims(pad->paddings_, ori_paddings, MAX_PAD_SIZE, pad->paddings_size_, 0); + pad->paddings_size_ = MAX_PAD_SIZE; + } + } + return NNACL_OK; +} + +int PadPrepare(KernelBase *self) { + NNACL_CHECK_TRUE_RET(self->in_size_ == TWO_TENSOR || self->in_size_ == THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_TRUE_RET(self->out_size_ == ONE_TENSOR, NNACL_ERR); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_FALSE(input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16, NNACL_ERR); + return NNACL_OK; +} + +KernelBase *CreatePad(OpParameter *param, int data_type) { + PadStruct *pad = (PadStruct *)malloc(sizeof(PadStruct)); + NNACL_CHECK_NULL_RETURN_NULL(pad); + memset(pad, 0, sizeof(PadStruct)); + + pad->data_type_ = data_type; + + PadParameter *pad_param = (PadParameter *)param; + pad->pad_mode_ = pad_param->pad_mode_; + pad->constant_value_ = pad_param->constant_value_; + pad->paddings_size_ = pad_param->padding_length; + memcpy(pad->paddings_, pad_param->paddings_, MAX_PAD_SIZE * sizeof(int)); + + pad->base_.Release = DefaultRelease; + pad->base_.Prepare = PadPrepare; + pad->base_.Resize = PadResize; + pad->base_.Compute = PadCompute; + return (KernelBase *)pad; +} + +REG_KERNEL_CREATOR(PrimType_PadFusion, kNumberTypeFloat32, CreatePad) +REG_KERNEL_CREATOR(PrimType_PadFusion, kNumberTypeFloat16, CreatePad) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h new file mode 100644 index 0000000000000000000000000000000000000000..157e3c87c5ee6cc5a788282f5599e39dd34e1ac9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h @@ -0,0 +1,51 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_KERNEL_PAD_H_
+#define NNACL_KERNEL_PAD_H_
+
+#include <stdlib.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/pad_parameter.h"
+
+typedef struct MirrorPadBlock {
+  int out_offset_;
+  int out_stride_[DEFAULT_PAD_NDIMS];
+  int size_[DEFAULT_PAD_NDIMS];
+} MirrorPadBlock;
+
+typedef struct PadStruct {
+  KernelBase base_;
+  int data_type_;
+  int mirror_offset_;
+  float constant_value_;
+  int pad_mode_;
+  int paddings_[MAX_PAD_SIZE];
+  int paddings_size_;
+  int in_[DEFAULT_PAD_NDIMS];
+  int out_[DEFAULT_PAD_NDIMS];
+  int in_strides_[DEFAULT_PAD_NDIMS];
+  int out_strides_[DEFAULT_PAD_NDIMS];
+  MirrorPadBlock *mirror_pad_block_;
+  int mirror_pad_block_size_;
+} PadStruct;
+
+KernelBase *CreatePad(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_PAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c
new file mode 100644
index 0000000000000000000000000000000000000000..752021ce336a0c1dcaca87f3ba039996dadb0701
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c
@@ -0,0 +1,159 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
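+/* Dispatch overview: PoolingImpl routes to the fp16 path (compiled only under
+ * ENABLE_FP16) or the fp32 path; the fp32 path further selects between the
+ * NC4HW4 and NHWC layouts, and any other input format is rejected. */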
+
+#include "nnacl_c/kernel/pooling.h"
+#include <float.h>
+#include "nnacl_c/pooling_parameter.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/pooling_fp16.h"
+#endif
+
+int PoolingF16RunImpl(PoolingStruct *pooling, int task_id) {
+#ifdef ENABLE_FP16
+  PoolingParameter *param = (PoolingParameter *)pooling->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+
+  float16_t *input_ptr = (float16_t *)pooling->base_.in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  float16_t *output_ptr = (float16_t *)pooling->base_.out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+
+  if (param->pool_mode_ == PoolMode_MaxPool) {
+    MaxPoolingFp16(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_);
+    return NNACL_OK;
+  } else {
+    return AvgPoolingFp16(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_);
+  }
+#endif
+  return NNACL_DISABLE_FP16;
+}
+
+int PoolingRunImpl(PoolingStruct *pooling, int task_id) {
+  PoolingParameter *param = (PoolingParameter *)pooling->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+
+  TensorC *input_tensor = pooling->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  float *input_ptr = (float *)input_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  float *output_ptr = (float *)pooling->base_.out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+
+  if (input_tensor->format_ == Format_NC4HW4) {
+    if (param->pool_mode_ == PoolMode_MaxPool) {
+      return MaxPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, param, &pooling->compute_, task_id,
+                                        pooling->base_.thread_nr_);
+    } else {
+      return AvgPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, param, &pooling->compute_, task_id,
+                                        pooling->base_.thread_nr_);
+    }
+  } else if (input_tensor->format_ == Format_NHWC) {
+    if (param->pool_mode_ == PoolMode_MaxPool) {
+      return MaxPooling(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_);
+    } else {
+      return AvgPooling(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_);
+    }
+  }
+
+  return NNACL_UNSUPPORTED_FORMAT;
+}
+
+int PoolingImpl(void *cdata, int task_id, float l, float r) {
+  PoolingStruct *pooling = (PoolingStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(cdata);
+  if (pooling->data_type_ == kNumberTypeFloat16) {
+    return PoolingF16RunImpl(pooling, task_id);
+  } else if (pooling->data_type_ == kNumberTypeFloat32) {
+    return PoolingRunImpl(pooling, task_id);
+  }
+  return NNACL_UNSUPPORTED_DATA_TYPE;
+}
+
+int PoolingCompute(KernelBase *self) {
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, PoolingImpl, self, self->thread_nr_);
+}
+
+int PoolingResize(KernelBase *self) {
+  PoolingStruct *pooling = (PoolingStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(pooling);
+  TensorC *in_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
+  TensorC *out_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+
+  PoolingComputeParam *compute = &pooling->compute_;
+  PoolingParameter *param = (PoolingParameter *)self->param_;
+
+  compute->input_batch_ = NNACLGetBatch(in_tensor);
+  compute->input_channel_ = NNACLGetChannel(in_tensor);
+  compute->input_h_ = NNACLGetHeight(in_tensor);
+  compute->input_w_ = NNACLGetWidth(in_tensor);
+  compute->output_batch_ = NNACLGetBatch(out_tensor);
+  compute->output_channel_ =
NNACLGetChannel(out_tensor); + compute->output_h_ = NNACLGetHeight(out_tensor); + compute->output_w_ = NNACLGetWidth(out_tensor); + compute->window_h_ = param->window_h_; + compute->window_w_ = param->window_w_; + if (param->global_) { + compute->window_h_ = compute->input_h_; + compute->window_w_ = compute->input_w_; + } + return NNACL_OK; +} + +int PoolingPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < 1, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < 1, NNACL_ERR); + + PoolingStruct *pooling = (PoolingStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pooling); + PoolingParameter *param = (PoolingParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float minf = pooling->data_type_ == kNumberTypeFloat32 ? -FLT_MAX : -FLT16_MAX; + float maxf = pooling->data_type_ == kNumberTypeFloat32 ? FLT_MAX : FLT16_MAX; + + if (param->act_type_ == ActType_Relu) { + minf = 0.f; + } else if (param->act_type_ == ActType_Relu6) { + minf = 0.f; + maxf = 6.f; + } + pooling->compute_.minf = minf; + pooling->compute_.maxf = maxf; + + return NNACL_OK; +} + +KernelBase *CreatePooling(OpParameter *param, int data_type) { + PoolingStruct *pooling = (PoolingStruct *)malloc(sizeof(PoolingStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(pooling); + memset(pooling, 0, sizeof(PoolingStruct)); + pooling->data_type_ = data_type; + pooling->base_.Release = DefaultRelease; + pooling->base_.Prepare = PoolingPrepare; + pooling->base_.Resize = PoolingResize; + pooling->base_.Compute = PoolingCompute; + return (KernelBase *)pooling; +} + +REG_KERNEL_CREATOR(PrimType_AvgPoolFusion, kNumberTypeFloat16, CreatePooling) +REG_KERNEL_CREATOR(PrimType_MaxPoolFusion, kNumberTypeFloat16, CreatePooling) +REG_KERNEL_CREATOR(PrimType_AvgPoolFusion, kNumberTypeFloat32, CreatePooling) +REG_KERNEL_CREATOR(PrimType_MaxPoolFusion, kNumberTypeFloat32, CreatePooling) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..7a95f0fc1e3126b8636a6fe4037cddf25d024c56 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h @@ -0,0 +1,54 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_POOLING_H_ +#define NNACL_KERNEL_POOLING_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PoolingComputeParam { + int input_w_; + int input_h_; + int input_batch_; + int input_channel_; + int output_w_; + int output_h_; + int output_batch_; + int output_channel_; + int window_w_; + int window_h_; + float minf; + float maxf; +} PoolingComputeParam; + +typedef struct Pooling3DComputeParam { + PoolingComputeParam pooling_compute_param_; + int input_d_; + int output_d_; + int window_d_; +} Pooling3DComputeParam; + +typedef struct PoolingStruct { + KernelBase base_; + PoolingComputeParam compute_; + int data_type_; +} PoolingStruct; + +KernelBase *CreatePooling(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c new file mode 100644 index 0000000000000000000000000000000000000000..39198cd3ad5034f7811f5e0d44cca28e5ed4a3db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c @@ -0,0 +1,79 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/pow.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/power_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/power_fp16.h" +#endif + +int PowImpl(void *cdata, int task_id, float l, float r) { + PowStruct *pow = (PowStruct *)cdata; + TensorC *input0 = pow->base_.in_[FIRST_INPUT]; + TensorC *input1 = pow->base_.in_[SECOND_INPUT]; + TensorC *output = pow->base_.out_[OUTPUT_INDEX]; + + int size = NNACLGetElementNum(input0); + int stride = UP_DIV(size, pow->base_.thread_nr_); + int len = MSMIN(stride, size - stride * task_id); + if (len <= 0) { + return NNACL_OK; + } + bool broadcast = !ShapeEqual(input0->shape_, input0->shape_size_, input1->shape_, input1->shape_size_); + float scale = ((PowParameter *)pow->base_.param_)->scale_; + float shift = ((PowParameter *)pow->base_.param_)->shift_; + int task_stride = stride * task_id; + + uint8_t *exp_addr = (uint8_t *)input1->data_; + void *cur_exp = NULL; + if (broadcast) { + cur_exp = exp_addr; + } else { + cur_exp = exp_addr + task_stride * DataTypeCSize(pow->data_type_); + } + + if (pow->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + return PowerFp16((float16_t *)input0->data_ + task_stride, (float16_t *)cur_exp, + (float16_t *)output->data_ + task_stride, len, scale, shift, broadcast); +#endif + } else if (pow->data_type_ == kNumberTypeFloat32) { + return Power((float *)input0->data_ + task_stride, (float *)cur_exp, (float *)output->data_ + task_stride, len, + scale, shift, broadcast); + } + return NNACL_POW_INVALID_DATA_TYPE; +} + +int PowCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, PowImpl, self, self->thread_nr_); +} + +KernelBase *CreatePow(OpParameter *param, int data_type) { + PowStruct *pow = (PowStruct *)malloc(sizeof(PowStruct)); + NNACL_CHECK_NULL_RETURN_NULL(pow); + pow->data_type_ = data_type; + pow->base_.Release = DefaultRelease; + pow->base_.Prepare = DefaultPrepare2In1Out; + pow->base_.Resize = DefaultResize; + pow->base_.Compute = PowCompute; + return (KernelBase *)pow; +} + +REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat32, CreatePow) +REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat16, CreatePow) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h new file mode 100644 index 0000000000000000000000000000000000000000..b77887222b062203fa52a01a3438456e6ad24990 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_POW_H_ +#define NNACL_KERNEL_POW_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PowStruct { + KernelBase base_; + int data_type_; +} PowStruct; + +KernelBase *CreatePow(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_POW_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c new file mode 100644 index 0000000000000000000000000000000000000000..72499684c5d1aa02476775dd57082884b33863c0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c @@ -0,0 +1,111 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/prelu.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/fp32/prelu_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/prelu_fp16.h" +#endif + +int PReluRun(void *cdata, int task_id, float l, float r) { + PReluStruct *prelu = (PReluStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + + int thread_num = prelu->base_.thread_nr_; + int num = prelu->channel_shared_ ? prelu->input_num_ : prelu->input_num_ / prelu->channel_num_; + int step = UP_DIV(num, thread_num); + int start = task_id * step; + int end = MSMIN(start + step, num); + + void *in_data = prelu->base_.in_[FIRST_INPUT]->data_; + void *out_data = prelu->base_.out_[OUTPUT_INDEX]->data_; + void *slope_data = prelu->base_.in_[SECOND_INPUT]->data_; + + if (prelu->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + if (prelu->channel_shared_) { + PReluShareChannelFp16((float16_t *)in_data, (float16_t *)out_data, ((float16_t *)slope_data)[0], start, end); + } else { + PReluFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)slope_data, start, end, prelu->channel_num_); + } +#endif + } else { + if (prelu->channel_shared_) { + PReluShareChannel((float *)in_data, (float *)out_data, ((float *)slope_data)[0], start, end); + } else { + PRelu((float *)in_data, (float *)out_data, (float *)slope_data, start, end, prelu->channel_num_); + } + } + return NNACL_OK; +} + +int PReluPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int PReluResize(KernelBase *self) { + PReluStruct *prelu = (PReluStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + prelu->input_num_ = NNACLGetElementNum(input); + prelu->channel_num_ = NNACLGetChannel(input); + return NNACL_OK; +} + +int PReluCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]->data_); + 
NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + PReluStruct *prelu = (PReluStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *slope = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(slope); + + int slope_num = NNACLGetElementNum(slope); + if (slope_num == Num1) { + prelu->channel_shared_ = true; + } else if (slope_num == NNACLGetChannel(input)) { + prelu->channel_shared_ = false; + } else { + return NNACL_PRELU_SLOPE_NUM_INVALID; + } + return self->env_->ParallelLaunch(self->env_->thread_pool_, PReluRun, self, self->thread_nr_); +} + +KernelBase *CreatePRelu(OpParameter *param, int data_type) { + PReluStruct *prelu = (PReluStruct *)malloc(sizeof(PReluStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prelu); + memset(prelu, 0, sizeof(PReluStruct)); + prelu->data_type_ = data_type; + prelu->base_.Prepare = PReluPrepare; + prelu->base_.Resize = PReluResize; + prelu->base_.Compute = PReluCompute; + prelu->base_.Release = DefaultRelease; + return (KernelBase *)prelu; +} + +REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat16, CreatePRelu) +REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat32, CreatePRelu) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h new file mode 100644 index 0000000000000000000000000000000000000000..e38d1d939be48e951486aa04ffd4249fd1b108c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_PRELU_H_ +#define NNACL_KERNEL_PRELU_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PReluStruct { + KernelBase base_; + int data_type_; + int input_num_; + int channel_num_; + bool channel_shared_; +} PReluStruct; + +KernelBase *CreatePRelu(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PRELU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c new file mode 100644 index 0000000000000000000000000000000000000000..05438d9060ab564beba0d7d6e957911f020bae1e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c @@ -0,0 +1,190 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/prior_box.h"
+#include <math.h>
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/fp32/prior_box_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int PriorBoxInitOutput(PriorBoxStruct *prior_box, const PriorBoxParameter *param, const float *different_aspect_ratios,
+                       int different_aspect_ratios_size) {
+  for (int i = 0; i < prior_box->fmap_h_; i++) {
+    float cy = i + param->offset;
+    for (int j = 0; j < prior_box->fmap_w_; j++) {
+      float cx = j + param->offset;
+      for (int32_t k = 0; k < param->min_sizes_size; k++) {
+        float min = param->min_sizes[k];
+        prior_box->output_[prior_box->output_size_++] = (cx - min / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_;
+        prior_box->output_[prior_box->output_size_++] = (cy - min / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_;
+        prior_box->output_[prior_box->output_size_++] = (cx + min / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_;
+        prior_box->output_[prior_box->output_size_++] = (cy + min / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_;
+
+        if (param->max_sizes_size > 0) {
+          float max = param->max_sizes[k];
+          NNACL_CHECK_FALSE(min * max <= 0, NNACL_PRIOR_BOX_VALUE_INVALID);
+          float prime = sqrt(min * max);
+          prior_box->output_[prior_box->output_size_++] = (cx - prime / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_;
+          prior_box->output_[prior_box->output_size_++] = (cy - prime / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_;
+          prior_box->output_[prior_box->output_size_++] = (cx + prime / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_;
+          prior_box->output_[prior_box->output_size_++] = (cy + prime / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_;
+        }
+
+        for (int m = 0; m < different_aspect_ratios_size; m++) {
+          float v = different_aspect_ratios[m];
+          if (fabs(v - 1.0f) < 1e-6) {
+            continue;
+          }
+          NNACL_CHECK_FALSE(v <= 0, NNACL_PRIOR_BOX_VALUE_INVALID);
+          float as_square_root = sqrt(v);
+          NNACL_CHECK_FALSE(as_square_root <= 0, NNACL_PRIOR_BOX_VALUE_INVALID);
+          float box_w = min * as_square_root;
+          float box_h = min / as_square_root;
+          prior_box->output_[prior_box->output_size_++] = (cx - box_w / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_;
+          prior_box->output_[prior_box->output_size_++] = (cy - box_h / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_;
+          prior_box->output_[prior_box->output_size_++] = (cx + box_w / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_;
+          prior_box->output_[prior_box->output_size_++] = (cy + box_h / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_;
+        }
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int RunPriorBox(void *cdata, int task_id, float l, float r) {
+  PriorBoxStruct *prior_box = (PriorBoxStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(prior_box);
+  TensorC *output_tensor = prior_box->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+  float *output_data = output_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_data);
+  return PriorBox(prior_box->output_, output_data, NNACLGetSize(output_tensor), task_id, prior_box->base_.thread_nr_);
+}
+
+int PriorBoxRelease(KernelBase *self) {
+  PriorBoxStruct *prior_box = (PriorBoxStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(prior_box);
+  if (prior_box->output_ != NULL) {
+    self->env_->Free(self->env_->allocator_, prior_box->output_);
+    prior_box->output_ = NULL;
+    prior_box->output_size_ = 0;
+  }
+  return NNACL_OK;
+}
+
+int PriorBoxResize(KernelBase *self) {
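+  /* Resize rebuilds the whole anchor table: it derives the feature-map and
+   * image geometry from the two inputs, collects the distinct aspect ratios
+   * (plus their reciprocals when param->flip is set), then precomputes every
+   * prior box into prior_box->output_, followed by clipping and the variance tail. */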
+  PriorBoxStruct *prior_box = (PriorBoxStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(prior_box);
+  PriorBoxParameter *param = (PriorBoxParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  TensorC *input0_tensor = prior_box->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input0_tensor);
+  TensorC *input1_tensor = prior_box->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input1_tensor);
+  TensorC *output_tensor = prior_box->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  /* feature-map geometry comes from the first input, image geometry from the second */
+  prior_box->fmap_w_ = NNACLGetWidth(input0_tensor);
+  NNACL_CHECK_ZERO_RETURN_ERR(prior_box->fmap_w_);
+  prior_box->fmap_h_ = NNACLGetHeight(input0_tensor);
+  NNACL_CHECK_ZERO_RETURN_ERR(prior_box->fmap_h_);
+  const int image_w = param->image_size_w > 0 ? param->image_size_w : NNACLGetWidth(input1_tensor);
+  const int image_h = param->image_size_h > 0 ? param->image_size_h : NNACLGetHeight(input1_tensor);
+
+  prior_box->step_w_ = param->step_w > 0.0f ? param->step_w : (float)(image_w) / prior_box->fmap_w_;
+  prior_box->step_h_ = param->step_h > 0.0f ? param->step_h : (float)(image_h) / prior_box->fmap_h_;
+
+  /* worst case holds 1.0f plus every ratio and its flipped reciprocal */
+  float *different_aspect_ratios =
+    (float *)self->env_->Alloc(self->env_->allocator_, (param->aspect_ratios_size * Num2 + 1) * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(different_aspect_ratios);
+  different_aspect_ratios[Index0] = 1.0f;
+  int different_aspect_ratios_size = 1;
+
+  float *aspect_ratios = param->aspect_ratios;
+  for (int32_t i = 0; i < param->aspect_ratios_size; i++) {
+    float ratio = aspect_ratios[i];
+
+    bool exist = false;
+    for (int k = 0; k < different_aspect_ratios_size; k++) {
+      if (fabs(ratio - different_aspect_ratios[k]) < 1e-6) {
+        exist = true;
+      }
+    }
+
+    if (!exist) {
+      different_aspect_ratios[different_aspect_ratios_size++] = ratio;
+      if (param->flip) {
+        NNACL_CHECK_FALSE(fabs(ratio) <= 1e-5, NNACL_PRIOR_BOX_RATIO_INVALID);
+        different_aspect_ratios[different_aspect_ratios_size++] = 1.0f / ratio;
+      }
+    }
+  }
+
+  PriorBoxRelease(self);
+  int size = Num4 + Num4 + different_aspect_ratios_size;
+  size = size * prior_box->fmap_h_ * prior_box->fmap_w_ * param->min_sizes_size;
+  size = size + UP_ROUND(NNACLGetHeight(output_tensor), COMM_SHAPE_SIZE);
+  size = size * sizeof(float);
+  NNACL_CHECK_MALLOC_SIZE(size);
+  prior_box->output_ = (float *)self->env_->Alloc(self->env_->allocator_, size);
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(prior_box->output_);
+  prior_box->output_size_ = 0;
+
+  int ret = PriorBoxInitOutput(prior_box, param, different_aspect_ratios, different_aspect_ratios_size);
+  self->env_->Free(self->env_->allocator_, different_aspect_ratios);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  /* clip boxes to [0, 1] */
+  if (param->clip) {
+    for (int i = 0; i < prior_box->output_size_; i++) {
+      float item = prior_box->output_[i];
+      if (item > 1.0f) {
+        item = 1.0f;
+      }
+      if (item < 0.0f) {
+        item = 0.0f;
+      }
+      prior_box->output_[i] = item;
+    }
+  }
+
+  /* variance */
+  for (int i = 0; i < NNACLGetHeight(output_tensor) / COMM_SHAPE_SIZE; i++) {
+    for (int j = 0; j < COMM_SHAPE_SIZE; j++) {
+      prior_box->output_[prior_box->output_size_++] = param->variances[j];
+    }
+  }
+  return NNACL_OK;
+}
+
+int PriorBoxCompute(KernelBase *self) {
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, RunPriorBox, self, self->thread_nr_);
+}
+
+KernelBase *CreatePriorBox(OpParameter *param, int data_type) {
+  PriorBoxStruct *prior_box = (PriorBoxStruct *)malloc(sizeof(PriorBoxStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prior_box);
+  memset(prior_box, 0, sizeof(PriorBoxStruct));
+
+  prior_box->base_.Prepare = DefaultPrepare2In1Out;
+  prior_box->base_.Resize = PriorBoxResize;
+  prior_box->base_.Release =
PriorBoxRelease; + prior_box->base_.Compute = PriorBoxCompute; + return (KernelBase *)prior_box; +} + +REG_KERNEL_CREATOR(PrimType_PriorBox, kNumberTypeFloat32, CreatePriorBox) +REG_KERNEL_CREATOR(PrimType_PriorBox, kNumberTypeInt8, CreatePriorBox) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h new file mode 100644 index 0000000000000000000000000000000000000000..5d728fddb08b569e6dce0264fb42f02a588c2a1a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_PRIOR_BOX_H_ +#define NNACL_KERNEL_PRIOR_BOX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PriorBoxStruct { + KernelBase base_; + float *output_; + int output_size_; + int fmap_h_; + int fmap_w_; + float step_h_; + float step_w_; +} PriorBoxStruct; + +KernelBase *CreatePriorBox(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PRIOR_BOX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c new file mode 100644 index 0000000000000000000000000000000000000000..a11d6d9019b76bafaf9cfb61a25c151db8d99d30 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/ragged_range.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/ragged_range_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/ragged_range_fp16.h" +#endif + +int RaggedRangeCompute(KernelBase *self) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(ragged_range); + + TensorC *input0 = self->in_[Index0]; + TensorC *input1 = self->in_[Index1]; + TensorC *input2 = self->in_[Index2]; + TensorC *output0 = self->out_[Index0]; + TensorC *output1 = self->out_[Index1]; + + if (input0->data_type_ == kNumberTypeFloat32) { + RaggedRangeFp32((float *)input0->data_, (float *)input1->data_, (float *)input2->data_, (int *)output0->data_, + (float *)output1->data_, ragged_range); + } else if (input0->data_type_ == kNumberTypeInt32) { + RaggedRangeInt((int *)input0->data_, (int *)input1->data_, (int *)input2->data_, (int *)output0->data_, + (int *)output1->data_, ragged_range); + } else if (input0->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + RaggedRangeFp16((float16_t *)input0->data_, (float16_t *)input1->data_, (float16_t *)input2->data_, + (int *)output0->data_, (float16_t *)output1->data_, ragged_range); +#endif + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int RaggedRangeResize(KernelBase *self) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(ragged_range); + + ragged_range->rows_ = self->out_[OUTPUT_INDEX]->shape_[Index0] - 1; + ragged_range->starts_is_scalar_ = self->in_[FIRST_INPUT]->shape_size_ == 0; + ragged_range->limits_is_scalar_ = self->in_[SECOND_INPUT]->shape_size_ == 0; + ragged_range->deltas_is_scalar_ = self->in_[THIRD_INPUT]->shape_size_ == 0; + return NNACL_OK; +} + +KernelBase *CreateRaggedRange(OpParameter *param, int data_type) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)malloc(sizeof(RaggedRangeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(ragged_range); + ragged_range->base_.Release = DefaultRelease; + ragged_range->base_.Prepare = DefaultPrepare3In2Out; + ragged_range->base_.Resize = RaggedRangeResize; + ragged_range->base_.Compute = RaggedRangeCompute; + return (KernelBase *)ragged_range; +} + +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeInt32, CreateRaggedRange) +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat16, CreateRaggedRange) +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat32, CreateRaggedRange) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h new file mode 100644 index 0000000000000000000000000000000000000000..e19ea067f5a95f8aaaeba51b3b3824ca2bee4882 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h @@ -0,0 +1,35 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_RAGGED_RANGE_H_ +#define NNACL_KERNEL_RAGGED_RANGE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct RaggedRangeStruct { + KernelBase base_; + int rows_; + bool starts_is_scalar_; + bool limits_is_scalar_; + bool deltas_is_scalar_; +} RaggedRangeStruct; + +KernelBase *CreateRaggedRange(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RAGGED_RANGE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c new file mode 100644 index 0000000000000000000000000000000000000000..812a50c109b96d4e70bade3f8a8945ada6701be6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/range.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/range_parameter.h" +#include "nnacl_c/fp32/range_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/range_fp16.h" +#endif + +int RangeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + int output_num = NNACLGetElementNum(output); + + if (self->in_size_ == THREE_TENSOR) { + TensorC *delta = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(delta); + + if (input->data_type_ == kNumberTypeFloat32) { + Range((float *)output->data_, *(float *)input->data_, *(float *)delta->data_, output_num); + } else if (input->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + RangeFp16((float16_t *)output->data_, *(float16_t *)input->data_, *(float16_t *)delta->data_, output_num); +#endif + } else if (input->data_type_ == kNumberTypeInt32) { + RangeInt((int *)output->data_, *(int *)input->data_, *(int *)delta->data_, output_num); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } else { + if (input->data_type_ == kNumberTypeInt32) { + RangeParameter *param = (RangeParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + RangeInt((int *)output->data_, param->start_, param->delta_, output_num); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + return NNACL_OK; +} + +KernelBase *CreateRange(OpParameter *param, int data_type) { + RangeStruct *range = (RangeStruct *)malloc(sizeof(RangeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(range); + range->base_.Release = DefaultRelease; + range->base_.Prepare = DefaultPrepare1In1Out; + range->base_.Resize = DefaultResize; + range->base_.Compute = RangeCompute; + return (KernelBase *)range; +} + +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeInt32, CreateRange) +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeFloat32, CreateRange) +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeFloat16, CreateRange) diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h new file mode 100644 index 0000000000000000000000000000000000000000..32ba24985fb506908bb79a11355bb80ebeb4b61d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RANGE_H_ +#define NNACL_KERNEL_RANGE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct RangeStruct { + KernelBase base_; +} RangeStruct; + +KernelBase *CreateRange(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RANGE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c new file mode 100644 index 0000000000000000000000000000000000000000..ce965034068ab2a4d21bed9a076abd686cacd409 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/rank.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +int RankCompute(KernelBase *self) { + size_t rank = self->in_[FIRST_INPUT]->shape_size_; + void *output_data = self->out_[OUTPUT_INDEX]->data_; + if (self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + *(float16_t *)output_data = (float16_t)rank; +#endif + } else { + *(float *)output_data = (float)rank; + } + return NNACL_OK; +} + +KernelBase *CreateRank(OpParameter *param, int data_type) { + RankStruct *rank = (RankStruct *)malloc(sizeof(RankStruct)); + NNACL_CHECK_NULL_RETURN_NULL(rank); + rank->base_.Release = DefaultRelease; + rank->base_.Prepare = DefaultPrepare1In1Out; + rank->base_.Resize = DefaultResize; + rank->base_.Compute = RankCompute; + return (KernelBase *)rank; +} + +REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat32, CreateRank) +REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat16, CreateRank) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h new file mode 100644 index 0000000000000000000000000000000000000000..ef2e55e5c2c02a8d667430d1ba6b521c4f4e4ccf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RANK_H_ +#define NNACL_KERNEL_RANK_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct RankStruct { + KernelBase base_; +} RankStruct; + +KernelBase *CreateRank(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RANK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c new file mode 100644 index 0000000000000000000000000000000000000000..4752357a3639307afd74def38ccce5286f3f85c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c @@ -0,0 +1,434 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/reduce.h"
+#include <math.h>
+#include "nnacl_c/fp32/reduce_fp32.h"
+#include "nnacl_c/kernel/reshape.h"
+#include "nnacl_c/nnacl_common.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+
+void InitialReduceKernelList(KernelBase *base) {
+  ReduceStruct *reduce = (ReduceStruct *)base;
+  ReduceParameter *param = (ReduceParameter *)(base->param_);
+
+  ReduceKernelList func_list[] = {{Reduce_Sum, ReduceSum, IntReduceSum, NULL, ReduceSumByLastAxis},
+                                  {Reduce_Mean, ReduceMean, IntReduceMean, NULL, NULL},
+                                  {Reduce_Max, ReduceMax, IntReduceMax, NULL, ReduceMaxByLastAxis},
+                                  {Reduce_Min, ReduceMin, IntReduceMin, NULL, NULL},
+                                  {Reduce_Prod, ReduceProd, IntReduceProd, NULL, NULL},
+                                  {Reduce_SumSquare, ReduceSum, IntReduceSum, NULL, NULL},
+                                  {Reduce_ASum, ReduceSum, IntReduceSum, NULL, NULL},
+                                  {Reduce_All, NULL, NULL, ReduceAll, NULL},
+                                  {Reduce_L2, ReduceL2Norm, NULL, NULL, NULL}};
+
+  size_t list_len = sizeof(func_list) / sizeof(ReduceKernelList);
+  for (size_t i = 0; i < list_len; ++i) {
+    if (param->mode_ == func_list[i].type_) {
+      reduce->compute_ = func_list[i];
+      return;
+    }
+  }
+}
+
+int CallReduceUnit(KernelBase *base, int task_id) {
+  ReduceStruct *reduce = (ReduceStruct *)base;
+  NNACL_CHECK_NULL_RETURN_ERR(reduce->src_data_);
+  NNACL_CHECK_NULL_RETURN_ERR(reduce->dst_data_);
+
+  if (reduce->data_type_ == kNumberTypeFloat32) {
+    if (reduce->inner_size_ == 1 && reduce->compute_.float_last_axis_func_ != NULL) {
+      return reduce->compute_.float_last_axis_func_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_,
+                                                    (float *)(reduce->src_data_), (float *)(reduce->dst_data_),
+                                                    task_id, reduce->base_.thread_nr_);
+    } else {
+      NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.float_function_);
+      return reduce->compute_.float_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_,
+                                              (float *)(reduce->src_data_), (float *)(reduce->dst_data_), task_id,
+                                              reduce->base_.thread_nr_);
+    }
+  }
+
+  if (reduce->data_type_ == kNumberTypeBool) {
+    NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.bool_function_);
+    return reduce->compute_.bool_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_,
+                                           (bool *)(reduce->src_data_), (bool *)(reduce->dst_data_), task_id,
+                                           reduce->base_.thread_nr_);
+  }
+
+  if (reduce->data_type_ == kNumberTypeInt32) {
+    NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.int_function_);
+    return reduce->compute_.int_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_,
+                                          (int *)(reduce->src_data_), (int *)(reduce->dst_data_), task_id,
+                                          reduce->base_.thread_nr_);
+  }
+
+  return NNACL_REDUCE_UNSUPPORTED_DATA_TYPE;
+}
+
+int ReduceImpl(void *cdata, int task_id, float l, float r) {
+  NNACL_CHECK_NULL_RETURN_ERR(cdata);
+  ReduceStruct *reduce = (ReduceStruct *)cdata;
+  return reduce->call_uint_((KernelBase *)reduce, task_id);
+}
+
+int CopyReduceInputToOutput(ReduceStruct *reduce) {
+  int total_num = NNACLGetElementNum(reduce->base_.in_[FIRST_INPUT]);
+  NNACL_CHECK_FALSE(total_num == 0, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID);
+  int block_num = UP_DIV(total_num, reduce->base_.thread_nr_);
+  int tmp_thread_num = UP_DIV(total_num, block_num);
+  NNACL_CHECK_FALSE(tmp_thread_num == 0, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID);
+
+  /* reuse the reshape kernel's parallel copy to forward the input unchanged */
+  ReshapeStruct reshape_struct;
+  reshape_struct.base_.in_ = reduce->base_.in_;
+  reshape_struct.base_.out_ = reduce->base_.out_;
+  reshape_struct.block_num_ = block_num;
+  reshape_struct.total_num_ = total_num;
+  reshape_struct.base_.thread_nr_ =
tmp_thread_num; + return reduce->base_.env_->ParallelLaunch(reduce->base_.env_->thread_pool_, ParallelReshape, &reshape_struct, + tmp_thread_num); +} + +int MallocReduceTmpBuffer(ReduceStruct *reduce) { + // Clean pointers in data_buffer for free condition checking in FreeReduceTmpBuffer. + memset(reduce->data_buffers_, 0, reduce->data_buffers_size_ * sizeof(void *)); + + for (int i = 0; i < reduce->data_buffers_size_; i++) { + reduce->data_buffers_[i] = reduce->base_.env_->Alloc( + reduce->base_.env_->allocator_, reduce->data_buffer_sizes_[i] * DataTypeCSize(reduce->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reduce->data_buffers_[i]); + } + return NNACL_OK; +} + +void FreeReduceTmpBuffer(ReduceStruct *reduce) { + for (int i = 0; i < reduce->data_buffers_size_; i++) { + if (reduce->data_buffers_[i] != NULL) { + reduce->base_.env_->Free(reduce->base_.env_->allocator_, reduce->data_buffers_[i]); + } + reduce->data_buffers_[i] = NULL; + } +} + +int CalculateReduceCoeffOutput(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + + if (reduce->data_type_ != kNumberTypeFloat32) { + return NNACL_REDUCE_COEFF_DATA_TYPE_INVALID; + } + TensorC *out_tensor = reduce->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + int num = NNACLGetElementNum(out_tensor); + + float *out_data = (float *)out_tensor->data_; + for (int i = 0; i < num; ++i) { + out_data[i] *= ((ReduceParameter *)reduce->base_.param_)->coeff; + } + return NNACL_OK; +} + +void HandleReduceASumAndSumSquare(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + if (reduce->data_type_ == kNumberTypeInt32 || reduce->data_type_ == kNumberTypeBool) { + return; + } + + TensorC *in_tensor = base->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(in_tensor); + float *data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(data); + + int num = NNACLGetElementNum(in_tensor); + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_ASum) { + for (int i = 0; i < num; ++i) { + if (data[i] < 0.0f) { + data[i] = 0.0f - data[i]; + } + } + } + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_SumSquare) { + for (int i = 0; i < num; ++i) { + data[i] = data[i] * data[i]; + } + return; + } +} + +int ReduceCheckInputsOutputs(ReduceStruct *reduce) { + NNACL_CHECK_FALSE(reduce->base_.in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(reduce->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + for (size_t i = 0; i < reduce->base_.in_size_; i++) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->base_.in_[i]); + } + for (size_t i = 0; i < reduce->base_.out_size_; i++) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->base_.out_[i]); + } + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (reduce->base_.in_size_ > ONE_TENSOR) { + TensorC *axes_tensor = reduce->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(axes_tensor); + NNACL_CHECK_FALSE(axes_tensor->data_type_ != kNumberTypeInt && axes_tensor->data_type_ != kNumberTypeInt32 && + axes_tensor->data_type_ != kNumberTypeInt64, + NNACL_REDUCE_AXES_TENSOR_ERROR); + } + return NNACL_OK; +} + +int ReduceCommonPrepare(ReduceStruct *reduce) { + int ret = ReduceCheckInputsOutputs(reduce); + if (ret != NNACL_OK) { + return ret; + } + + if (reduce->base_.in_size_ == ONE_TENSOR) { + reduce->num_axes_ = 0; + return NNACL_OK; + } + + TensorC *axes_tensor = reduce->base_.in_[SECOND_INPUT]; + reduce->num_axes_ = 
NNACLGetElementNum(axes_tensor); + if (axes_tensor->data_ != NULL && (reduce->num_axes_ <= 0 || reduce->num_axes_ > MAX_SHAPE_SIZE)) { + return NNACL_REDUCE_AXES_TENSOR_ERROR; + } + if (axes_tensor->data_ == NULL) { + reduce->num_axes_ = reduce->base_.in_[FIRST_INPUT]->shape_size_; + for (int i = 0; i < reduce->num_axes_; i++) { + reduce->axes_[i] = i; + } + } else { + if (axes_tensor->data_type_ == kNumberTypeInt32 || axes_tensor->data_type_ == kNumberTypeInt) { + NNACL_CHECK_FALSE(NNACLGetSize(axes_tensor) == 0, NNACL_REDUCE_AXES_TENSOR_ERROR); + (void)memcpy(reduce->axes_, axes_tensor->data_, NNACLGetSize(axes_tensor)); + } else { + int64_t *axes_data = axes_tensor->data_; + for (size_t i = 0; i < reduce->num_axes_; i++) { + reduce->axes_[i] = (int32_t)axes_data[i]; + } + } + } + + return NNACL_OK; +} + +int CheckReduceParameters(ReduceStruct *reduce) { + int input_shape_size = reduce->base_.in_[FIRST_INPUT]->shape_size_; + NNACL_CHECK_FALSE(reduce->num_axes_ > input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + for (int i = 0; i < reduce->num_axes_; i++) { + NNACL_CHECK_FALSE(reduce->axes_[i] < -input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + NNACL_CHECK_FALSE(reduce->axes_[i] >= input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + if (reduce->axes_[i] < 0) { + reduce->axes_[i] += input_shape_size; + } + } + + if (((ReduceParameter *)reduce->base_.param_)->reduce_to_end_) { + // actual num of axes to reduce + reduce->num_axes_ = (int)(input_shape_size)-reduce->axes_[0]; + for (int i = 1; i < reduce->num_axes_; ++i) { + reduce->axes_[i] = reduce->axes_[0] + i; + } + } + + if (reduce->num_axes_ == 0) { + for (int i = 0; i < input_shape_size; i++) { + reduce->axes_[i] = i; + } + reduce->num_axes_ = input_shape_size; + } + return NNACL_OK; +} + +void ReduceCalculateInnerOuterSize(ReduceStruct *reduce) { + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + int tmp_input_shape[MAX_SHAPE_SIZE]; + memcpy(tmp_input_shape, input_tensor->shape_, MAX_SHAPE_SIZE * sizeof(int)); + reduce->offset_size_ = 0; + + for (int i = 0; i < reduce->num_axes_; ++i) { + int axis = reduce->axes_[i]; + int outer_size = 1; + for (int j = 0; j < axis; j++) { + outer_size *= tmp_input_shape[j]; + } + reduce->outer_sizes_[reduce->offset_size_] = outer_size; + + int inner_size = 1; + for (int k = axis + 1; k < input_tensor->shape_size_; k++) { + inner_size *= tmp_input_shape[k]; + } + reduce->inner_sizes_[reduce->offset_size_] = inner_size; + reduce->axis_sizes_[reduce->offset_size_] = tmp_input_shape[axis]; + + reduce->offset_size_++; + tmp_input_shape[axis] = 1; + } +} + +void ReduceCalculateTmpBufferSize(ReduceStruct *reduce) { + reduce->data_buffers_size_ = 0; + + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + int tmp_input_shape[MAX_SHAPE_SIZE]; + memcpy(tmp_input_shape, input_tensor->shape_, MAX_SHAPE_SIZE * sizeof(int)); + // calculate size of buffer to malloc for each reducing axis + for (int i = 0; i < reduce->num_axes_ - 1; i++) { + int axis = reduce->axes_[i]; + size_t size = 1; + for (size_t j = 0; j < input_tensor->shape_size_; j++) { + if (axis != (int)(j)) { + size *= (size_t)(tmp_input_shape[j]); + } + } + reduce->data_buffer_sizes_[reduce->data_buffers_size_++] = size; + tmp_input_shape[axis] = 1; + } +} + +void ReduceDecideIfOnlyCopy(ReduceStruct *reduce) { + ReduceModeC can_not_copy[] = {Reduce_SumSquare, Reduce_ASum, Reduce_All, Reduce_L2}; + for (int i = 0; i < sizeof(can_not_copy) / sizeof(ReduceModeC); i++) { + if (can_not_copy[i] == 
((ReduceParameter *)reduce->base_.param_)->mode_) {
+      reduce->only_copy_ = false;
+      return;
+    }
+  }
+
+  int *in_shape = reduce->base_.in_[FIRST_INPUT]->shape_;
+
+  for (int i = 0; i < reduce->num_axes_; i++) {
+    int axis = reduce->axes_[i];
+    if (in_shape[axis] != 1) {
+      reduce->only_copy_ = false;
+      return;
+    }
+  }
+  reduce->only_copy_ = true;
+}
+
+int ReducePrepare(struct KernelBase *self) {
+  NNACL_CHECK_NULL_RETURN_ERR(self);
+  ReduceStruct *reduce = (ReduceStruct *)self;
+
+  NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+
+  int ret = ReduceCommonPrepare(reduce);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  reduce->init_kernel_list_(self);
+  return NNACL_OK;
+}
+
+int ReduceResize(struct KernelBase *self) {
+  NNACL_CHECK_NULL_RETURN_ERR(self);
+  ReduceStruct *reduce = (ReduceStruct *)self;
+
+  int ret = CheckReduceParameters(reduce);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ReduceDecideIfOnlyCopy(reduce);
+  ReduceCalculateTmpBufferSize(reduce);
+  ReduceCalculateInnerOuterSize(reduce);
+
+  if (reduce->num_axes_ == 1) {
+    self->thread_nr_ = self->UpdateThread(
+      TC_TYPE(PrimType_ReduceFusion, ((ReduceParameter *)reduce->base_.param_)->mode_),
+      reduce->inner_sizes_[Index0] * reduce->axis_sizes_[Index0],
+      reduce->inner_sizes_[Index0] * reduce->axis_sizes_[Index0], reduce->outer_sizes_[Index0], self->thread_nr_);
+  } else {
+    self->thread_nr_ = self->UpdateThread(TC_TYPE(PrimType_ReduceFusion, Reduce_Max + 1), 0, 0,
+                                          NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_);
+  }
+  return NNACL_OK;
+}
+
+int ReduceCompute(struct KernelBase *self) {
+  NNACL_CHECK_NULL_RETURN_ERR(self);
+  ReduceStruct *reduce = (ReduceStruct *)self;
+  NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != reduce->data_type_, NNACL_ERR);
+
+  if (reduce->only_copy_) {
+    return CopyReduceInputToOutput(reduce);
+  }
+
+  int ret = MallocReduceTmpBuffer(reduce);
+  if (ret != NNACL_OK) {
+    FreeReduceTmpBuffer(reduce);
+    return ret;
+  }
+
+  reduce->src_data_ = self->in_[FIRST_INPUT]->data_;
+  reduce->handle_sum_square_(self);
+  /* reduce one axis per pass, chaining through the temporary buffers */
+  for (int i = 0; i < reduce->num_axes_; i++) {
+    if (i != (reduce->num_axes_ - 1)) {
+      reduce->dst_data_ = reduce->data_buffers_[i];
+    } else {
+      reduce->dst_data_ = self->out_[OUTPUT_INDEX]->data_;
+    }
+    reduce->outer_size_ = reduce->outer_sizes_[i];
+    reduce->inner_size_ = reduce->inner_sizes_[i];
+    reduce->axis_size_ = reduce->axis_sizes_[i];
+    NNACL_CHECK_FALSE(reduce->axis_size_ == 0, NNACL_REDUCE_AXIS_SIZE_ERROR);
+
+    ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ReduceImpl, self, self->thread_nr_);
+    if (ret != NNACL_OK) {
+      FreeReduceTmpBuffer(reduce);
+      return ret;
+    }
+    reduce->src_data_ = reduce->dst_data_;
+  }
+
+  ReduceParameter *param = (ReduceParameter *)reduce->base_.param_;
+  if (param->reduce_to_end_ && fabsf(param->coeff) > 1e-5) {
+    ret = reduce->calculate_coeff_(self);
+  }
+
+  FreeReduceTmpBuffer(reduce);
+  return ret;
+}
+
+KernelBase *CreateReduce(OpParameter *param, int data_type) {
+  ReduceStruct *reduce = (ReduceStruct *)malloc(sizeof(ReduceStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reduce);
+  memset(reduce, 0, sizeof(ReduceStruct));
+  reduce->data_type_ = data_type;
+  reduce->base_.Release = DefaultRelease;
+  reduce->base_.Prepare = ReducePrepare;
+  reduce->base_.Resize = ReduceResize;
+  reduce->base_.Compute = ReduceCompute;
+  reduce->handle_sum_square_ = HandleReduceASumAndSumSquare;
+  reduce->calculate_coeff_ = CalculateReduceCoeffOutput;
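+  /* These handlers stay behind function pointers so a variant kernel (for
+   * example an fp16 build) could swap in its own implementations without
+   * touching the shared Prepare/Resize/Compute flow; the defaults below
+   * cover the fp32, int32, and bool paths. */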
+  reduce->init_kernel_list_ = InitialReduceKernelList;
+  reduce->call_uint_ = CallReduceUnit;
+  return (KernelBase *)reduce;
+}
+
+REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeBool, CreateReduce)
+REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeInt32, CreateReduce)
+REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeFloat32, CreateReduce)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h
new file mode 100644
index 0000000000000000000000000000000000000000..18f756497e2e0f38fee764606aaff3b39533cb19
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_REDUCE_H_
+#define NNACL_KERNEL_REDUCE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct ReduceKernelList {
+  int type_;
+  int (*float_function_)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
+                         float *dst_data, const int tid, const int thread_num);
+  int (*int_function_)(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
+                       int *dst_data, const int tid, const int thread_num);
+  int (*bool_function_)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data,
+                        bool *dst_data, const int tid, const int thread_num);
+  int (*float_last_axis_func_)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
+                               float *dst_data, const int tid, const int thread_num);
+} ReduceKernelList;
+
+typedef struct ReduceStruct {
+  KernelBase base_;
+  bool only_copy_;
+  int num_axes_;
+  TypeIdC data_type_;
+  int axes_[MAX_SHAPE_SIZE];
+
+  void *data_buffers_[MAX_SHAPE_SIZE];
+  size_t data_buffer_sizes_[MAX_SHAPE_SIZE];
+  int data_buffers_size_;
+  ReduceModeC mode_;
+
+  int outer_sizes_[MAX_SHAPE_SIZE];
+  int inner_sizes_[MAX_SHAPE_SIZE];
+  int axis_sizes_[MAX_SHAPE_SIZE];
+  int offset_size_;
+
+  int outer_size_;
+  int inner_size_;
+  int axis_size_;
+
+  void *src_data_;
+  void *dst_data_;
+  ReduceKernelList compute_;
+
+  void (*handle_sum_square_)(KernelBase *base);
+  void (*init_kernel_list_)(KernelBase *base);
+  int (*calculate_coeff_)(KernelBase *base);
+  int (*call_uint_)(KernelBase *base, int task_id);
+} ReduceStruct;
+
+KernelBase *CreateReduce(OpParameter *param, int data_type);
+int ReducePrepare(KernelBase *self);
+int ReduceResize(KernelBase *self);
+int ReduceCompute(KernelBase *self);
+
+#endif  // NNACL_KERNEL_REDUCE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c
new file mode 100644
index 0000000000000000000000000000000000000000..bc6271e94433a6607a474605ff7570809d08cf26
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2023
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/reshape.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" + +int kMinCostPerThread = 16384; + +int ParallelReshape(void *param, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(param); + ReshapeStruct *reshape = (ReshapeStruct *)param; + + int data_size = (int)DataTypeCSize(reshape->base_.in_[0]->data_type_); + uint8_t *in_start = (uint8_t *)(reshape->base_.in_[0]->data_) + task_id * reshape->block_num_ * data_size; + uint8_t *out_start = (uint8_t *)(reshape->base_.out_[0]->data_) + task_id * reshape->block_num_ * data_size; + int copy_num = reshape->block_num_; + if (task_id == (reshape->base_.thread_nr_ - 1)) { + copy_num = reshape->total_num_ - task_id * reshape->block_num_; + } + (void)memcpy(out_start, in_start, copy_num * data_size); + return NNACL_OK; +} + +int ReshapeResize(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReshapeStruct *reshape = (ReshapeStruct *)self; + reshape->total_num_ = NNACLGetElementNum(self->in_[0]); + if (reshape->total_num_ == 0) { + return NNACL_OK; + } + + self->thread_nr_ = MSMIN(self->thread_nr_, UP_DIV(reshape->total_num_, kMinCostPerThread)); + if (self->thread_nr_ < 1) { + self->thread_nr_ = 1; + } + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + reshape->block_num_ = UP_DIV(reshape->total_num_, self->thread_nr_); + NNACL_CHECK_ZERO_RETURN_ERR(reshape->block_num_); + self->thread_nr_ = UP_DIV(reshape->total_num_, reshape->block_num_); + + return NNACL_OK; +} + +int ReshapeCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ParallelReshape, self, self->thread_nr_); +} + +KernelBase *CreateReshape(OpParameter *param, int data_type) { + ReshapeStruct *reshape = (ReshapeStruct *)malloc(sizeof(ReshapeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reshape); + reshape->base_.Release = DefaultRelease; + reshape->base_.Prepare = DefaultPrepare1In1Out; + reshape->base_.Resize = ReshapeResize; + reshape->base_.Compute = ReshapeCompute; + return (KernelBase *)reshape; +} + +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat16, CreateReshape) 
+REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt8, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeUInt8, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt64, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeBool, CreateReshape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h new file mode 100644 index 0000000000000000000000000000000000000000..dca87f68c2afc73f2e855b1ebb349da34f4d397e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RESHAPE_H_ +#define NNACL_KERNEL_RESHAPE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ReshapeStruct { + KernelBase base_; + int block_num_; + int total_num_; +} ReshapeStruct; + +int ParallelReshape(void *param, int task_id, float l, float r); + +KernelBase *CreateReshape(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RESHAPE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c new file mode 100644 index 0000000000000000000000000000000000000000..ade5ccf5a9e559152dd76027b35d7b4d69321643 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c @@ -0,0 +1,166 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/reverse.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/reverse_parameter.h"
+#include "nnacl_c/fp32/reverse_fp32.h"
+
+int ReverseStride(TensorC *input, int index) {
+  int stride = 1;
+  for (int i = index + 1; i < (int)input->shape_size_; i++) {
+    stride *= input->shape_[i];
+  }
+  return stride;
+}
+
+int ReverseRun(void *cdata, int task_id, float l, float r) {
+  ReverseStruct *reverse = (ReverseStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+
+  int offset = task_id * reverse->thread_stride_;
+  int count = NNACL_MIN(reverse->thread_stride_, reverse->data_size_ - offset);
+  if (count <= 0) {
+    return NNACL_OK;
+  }
+
+  float *in_ptr = (float *)reverse->base_.in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(in_ptr);
+  float *out_ptr = (float *)reverse->base_.out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(out_ptr);
+  /* the last task may cover fewer than thread_stride_ elements, so pass count */
+  return Reverse(in_ptr + offset, out_ptr, count, reverse->tmp_ + offset);
+}
+
+int ReverseUpdateAxisInfo(ReverseStruct *reverse) {
+  ReverseParameter *reverse_param = (ReverseParameter *)reverse->base_.param_;
+  int in_shape_len = reverse->base_.in_[FIRST_INPUT]->shape_size_;
+  for (int i = 0; i < reverse_param->num_axis_; ++i) {
+    if (reverse_param->axis_[i] < 0) {
+      reverse_param->axis_[i] += in_shape_len;
+    }
+    if (reverse_param->axis_[i] < 0 || reverse_param->axis_[i] >= in_shape_len) {
+      return NNACL_REVERSE_AXIS_VALUE_INVALID;
+    }
+  }
+  return NNACL_OK;
+}
+
+int ReverseCompute(KernelBase *self) {
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, ReverseRun, self, self->thread_nr_);
+}
+
+int ReversePrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+  ReverseStruct *reverse = (ReverseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+  if (((ReverseParameter *)self->param_)->num_axis_ < Num1) {
+    return NNACL_REVERSE_AXIS_INVALID;
+  }
+  return NNACL_OK;
+}
+
+int ReverseRelease(KernelBase *self) {
+  ReverseStruct *reverse = (ReverseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+  if (reverse->tmp_ != NULL) {
+    self->env_->Free(self->env_->allocator_, reverse->tmp_);
+    reverse->tmp_ = NULL;
+  }
+  return NNACL_OK;
+}
+
+int ReverseResize(KernelBase *self) {
+  ReverseStruct *reverse = (ReverseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+
+  TensorC *input = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  TensorC *output = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+
+  // trans negative to positive axis
+  int ret = ReverseUpdateAxisInfo(reverse);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  reverse->data_size_ = NNACLGetElementNum(input);
+  if (NNACLGetElementNum(output) != reverse->data_size_) {
+    return NNACL_REVERSE_DATA_SIZE_INVALID;
+  }
+
+  self->thread_nr_ = NNACL_MIN(self->thread_nr_, reverse->data_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
+  reverse->thread_stride_ = UP_DIV(reverse->data_size_, self->thread_nr_);
+
+  ReverseParameter *reverse_param = (ReverseParameter *)self->param_;
+  if (reverse_param->num_axis_ > input->shape_size_) {
+    return NNACL_REVERSE_NUM_AXIS_INVALID;
+  }
+  if (input->shape_size_ > REVERSE_SHAPE_MAX_SIZE) {
+    return NNACL_REVERSE_NUM_AXIS_INVALID;
+  }
+
+  (void)self->Release(self);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(reverse->data_size_, sizeof(int), NNACL_ERR);
+  reverse->tmp_ = (int *)self->env_->Alloc(self->env_->allocator_, reverse->data_size_ * sizeof(int));
+
NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reverse->tmp_); + memset(reverse->tmp_, 0, reverse->data_size_ * sizeof(int)); + + for (int i = 0; i < reverse_param->num_axis_; i++) { + int axis = reverse_param->axis_[i]; + int stride = ReverseStride(input, axis); + reverse->strides_[i] = stride; + reverse->in_count_[i] = input->shape_[axis]; + reverse->out_count_[i] = 1; + for (int j = 0; j < axis; j++) { + reverse->out_count_[i] *= input->shape_[j]; + } + } + + int out; + int in; + int C; + int m; + for (int i = 0; i < reverse->data_size_; ++i) { + int tmp = i; + for (int j = 0; j < reverse_param->num_axis_; ++j) { + C = reverse->in_count_[j]; + out = tmp / (C * reverse->strides_[j]); + in = tmp / reverse->strides_[j] - out * C; + m = tmp % reverse->strides_[j]; + tmp = out * C * reverse->strides_[j] + reverse->strides_[j] * (C - 1 - in) + m; + } + reverse->tmp_[i] = tmp; + } + + return NNACL_OK; +} + +KernelBase *CreateReverse(OpParameter *param, int data_type) { + ReverseStruct *reverse = (ReverseStruct *)malloc(sizeof(ReverseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reverse); + memset(reverse, 0, sizeof(ReverseStruct)); + reverse->base_.Release = ReverseRelease; + reverse->base_.Prepare = ReversePrepare; + reverse->base_.Resize = ReverseResize; + reverse->base_.Compute = ReverseCompute; + return (KernelBase *)reverse; +} + +REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeFloat32, CreateReverse) +REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeInt32, CreateReverse) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h new file mode 100644 index 0000000000000000000000000000000000000000..b69760d68176e16748f4b6d0dad3525609b34114 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_REVERSE_H_ +#define NNACL_KERNEL_REVERSE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int thread_stride_; + int data_size_; + int *tmp_; + int strides_[COMM_SHAPE_SIZE]; + int in_count_[COMM_SHAPE_SIZE]; + int out_count_[COMM_SHAPE_SIZE]; +} ReverseStruct; + +KernelBase *CreateReverse(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_REVERSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c new file mode 100644 index 0000000000000000000000000000000000000000..d87f6dd9e18d117c3b6d6c2ff03a2e94e6b64fd0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c @@ -0,0 +1,333 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/scale.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/scale_parameter.h"
+#include "nnacl_c/fp32/scale_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/utils_fp16.h"
+#include "nnacl_c/fp16/scale_fp16.h"
+#endif
+
+int ScaleRunF16(ScaleStruct *scale, int task_id, ActType act_type) {
+#ifdef ENABLE_FP16
+  switch (act_type) {
+    case ActType_Relu6:
+      DoScaleRelu6Fp16((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_,
+                       (const float16_t *)scale->offset_, task_id, scale);
+      break;
+    case ActType_Relu:
+      Fp16DoScaleRelu((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_,
+                      (const float16_t *)scale->offset_, task_id, scale);
+      break;
+    case ActType_No:
+      DoScaleFp16((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_,
+                  (const float16_t *)scale->offset_, task_id, scale);
+      break;
+    default:
+      return NNACL_ERR;
+  }
+  return NNACL_OK;
+#endif
+  return NNACL_DISABLE_FP16;
+}
+
+int ScaleInitInputDataType(ScaleStruct *scale) {
+  if (scale->data_type_ == kNumberTypeFloat32) {
+    return NNACL_OK;
+  }
+
+#ifdef ENABLE_FP16
+  TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(scale_tensor);
+  if (scale_tensor->data_type_ != kNumberTypeFloat16 && scale->malloc_scale_ == false) {
+    scale->malloc_scale_ = true;
+    scale->scale_ = GetOrAllocFp16Data(scale_tensor, scale->base_.env_, true);
+  } else {
+    scale->malloc_scale_ = false;
+    scale->scale_ = NULL;
+  }
+
+  if (scale->base_.in_size_ == TWO_TENSOR) {
+    /* already done in prepare */
+    return NNACL_OK;
+  }
+
+  TensorC *offset_tensor = scale->base_.in_[THIRD_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(offset_tensor);
+  if (offset_tensor->data_type_ != kNumberTypeFloat16 && scale->malloc_offset_ == false) {
+    scale->malloc_offset_ = true;
+    scale->offset_ = GetOrAllocFp16Data(offset_tensor, scale->base_.env_, true);
+  } else {
+    scale->malloc_offset_ = false;
+    scale->offset_ = NULL;
+  }
+  return NNACL_OK;
+#endif
+  return NNACL_DISABLE_FP16;
+}
+
+int ScaleRunF32(ScaleStruct *scale, int task_id, ActType act_type) {
+  switch (act_type) {
+    case ActType_Relu6:
+      DoScaleRelu6((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_,
+                   (const float *)scale->offset_, task_id, scale);
+      break;
+    case ActType_Relu:
+      DoScaleRelu((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_,
+                  (const float *)scale->offset_, task_id, scale);
+      break;
+    case ActType_No:
+      DoScale((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_,
+              (const float *)scale->offset_, task_id, scale);
+      break;
+    default:
+      return NNACL_SCALE_UNSUPPORT_ACT_TYPE;
+  }
+  return NNACL_OK;
+}
+
+int ScaleRun(void *cdata, int task_id, float l, float r) {
+  ScaleStruct *scale = (ScaleStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(scale);
+  ActType act_type = ((ScaleParameter *)scale->base_.param_)->activation_type_;
+  if (scale->data_type_ == kNumberTypeFloat16) {
+    return ScaleRunF16(scale, task_id, act_type);
+  } else if (scale->data_type_ == kNumberTypeFloat32) {
+    return ScaleRunF32(scale, task_id, act_type);
+  }
+  return NNACL_UNSUPPORTED_DATA_TYPE;
+}
+
+int ScaleCalculateParameter(ScaleStruct *scale) {
+  TensorC *input_tensor = scale->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(scale_tensor);
+  TensorC *output_tensor = scale->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  scale->outer_size_ = 1;
+  scale->axis_size_ = 1;
+  scale->inner_size_ = 1;
+  for (int i = 0; i < scale->axis_; i++) {
+    scale->outer_size_ *= input_tensor->shape_[i];
+  }
+  for (size_t i = 0; i < scale_tensor->shape_size_; i++) {
+    scale->axis_size_ *= input_tensor->shape_[i + scale->axis_];
+  }
+  for (size_t i = scale->axis_ + scale_tensor->shape_size_; i < input_tensor->shape_size_; i++) {
+    scale->inner_size_ *= input_tensor->shape_[i];
+  }
+
+  scale->base_.thread_nr_ = MSMIN(scale->base_.thread_nr_, scale->outer_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(scale->base_.thread_nr_);
+
+  return NNACL_OK;
+}
+
+int ScaleInitScaleOffset(ScaleStruct *scale) {
+  TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(scale_tensor);
+  int data_type_size = DataTypeCSize(scale->data_type_);
+
+  if (scale->base_.in_size_ == TWO_TENSOR) {
+    scale->malloc_offset_ = true;
+    int malloc_size = NNACLGetElementNum(scale_tensor) * data_type_size;
+    NNACL_CHECK_MALLOC_SIZE(malloc_size);
+    scale->offset_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->offset_);
+    memset(scale->offset_, 0, malloc_size);
+  }
+
+  if (scale->data_type_ == kNumberTypeFloat16) {
+    /* handle fp16 scale and offset in compute */
+    return NNACL_OK;
+  }
+
+  if (scale_tensor->data_ != NULL) {
+    scale->malloc_scale_ = true;
+    int malloc_size = NNACLGetElementNum(scale_tensor) * data_type_size;
+    NNACL_CHECK_MALLOC_SIZE(malloc_size);
+    scale->scale_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->scale_);
+    (void)memcpy(scale->scale_, scale_tensor->data_, malloc_size);
+  } else {
+    scale->malloc_scale_ = false;
+    scale->scale_ = NULL;
+  }
+
+  if (scale->base_.in_size_ == TWO_TENSOR) {
+    return NNACL_OK;
+  }
+  NNACL_CHECK_FALSE(scale->base_.in_size_ != THREE_TENSOR, NNACL_SCALE_INPUT_NUM_INVALID);
+
+  TensorC *offset_tensor = scale->base_.in_[THIRD_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(offset_tensor);
+  if (offset_tensor->data_ != NULL) {
+    scale->malloc_offset_ = true;
+    int malloc_size = NNACLGetElementNum(offset_tensor) * data_type_size;
+    NNACL_CHECK_MALLOC_SIZE(malloc_size);
+    scale->offset_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->offset_);
+    (void)memcpy(scale->offset_, offset_tensor->data_, malloc_size);
+  } else {
+    scale->malloc_offset_ = false;
+    scale->offset_ = NULL;
+  }
+
+  return NNACL_OK;
+}
+
+int ScaleCheckInputsOutputs(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+
+  for (size_t i = 0; i < self->in_size_; i++) {
+    TensorC *input_tensor = self->in_[i];
+    NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+    if (input_tensor->data_type_ != kNumberTypeFloat32 && input_tensor->data_type_ != kNumberTypeFloat16) {
+      return
NNACL_UNSUPPORTED_DATA_TYPE; + } + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + if (output_tensor->data_type_ != kNumberTypeFloat32 && output_tensor->data_type_ != kNumberTypeFloat16) { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int ScaleRelease(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + if (scale->malloc_scale_ && scale->scale_ != NULL) { + self->env_->Free(self->env_->allocator_, scale->scale_); + scale->scale_ = NULL; + scale->malloc_scale_ = false; + } + + if (scale->malloc_offset_ && scale->offset_ != NULL) { + self->env_->Free(self->env_->allocator_, scale->offset_); + scale->offset_ = NULL; + scale->malloc_offset_ = false; + } + return NNACL_OK; +} + +int ScaleResize(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *scale_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + + int origin_axis = ((ScaleParameter *)self->param_)->axis_; + scale->axis_ = origin_axis < 0 ? origin_axis + input_tensor->shape_size_ : origin_axis; + + for (size_t i = 0; i < scale_tensor->shape_size_; i++) { + if (i + scale->axis_ >= input_tensor->shape_size_) { + return NNACL_SCALE_AXIS_AND_SHAPE_UNMATCH; + } + if (input_tensor->shape_[i + scale->axis_] != scale_tensor->shape_[i]) { + return NNACL_SCALE_SCALE_SHAPE_UNMATCH; + } + } + + int ret = ScaleCalculateParameter(scale); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int ScaleCompute(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + scale->input_ = input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->input_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + scale->output_ = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->output_); + + int ret = ScaleInitInputDataType(scale); + if (ret != NNACL_OK) { + return ret; + } + + if (!scale->malloc_scale_) { + TensorC *scale_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + scale->scale_ = scale_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->scale_); + } + + if (!scale->malloc_offset_) { + TensorC *offset_tensor = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(offset_tensor); + scale->offset_ = offset_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->offset_); + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ScaleRun, self, self->thread_nr_); +} + +int ScalePrepare(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + int ret = ScaleCheckInputsOutputs(self); + if (ret != NNACL_OK) { + return ret; + } + + ret = ScaleInitScaleOffset(scale); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +KernelBase *CreateScale(OpParameter *param, int data_type) { + ScaleStruct *scale = (ScaleStruct *)malloc(sizeof(ScaleStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(scale); + memset(scale, 0, sizeof(ScaleStruct)); + scale->data_type_ = data_type; + scale->scale_ = NULL; + scale->offset_ = NULL; + scale->malloc_scale_ = false; + scale->malloc_offset_ = false; + scale->base_.Prepare = ScalePrepare; + 
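+  /* Lifecycle note: Prepare checks dtypes and copies constant scale/offset
+   * data, Resize re-derives axis_ and the outer/axis/inner split sizes, and
+   * Compute (re)binds tensor data and launches ScaleRun across the pool. */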
scale->base_.Resize = ScaleResize; + scale->base_.Compute = ScaleCompute; + scale->base_.Release = ScaleRelease; + return (KernelBase *)scale; +} + +REG_KERNEL_CREATOR(PrimType_ScaleFusion, kNumberTypeFloat16, CreateScale) +REG_KERNEL_CREATOR(PrimType_ScaleFusion, kNumberTypeFloat32, CreateScale) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h new file mode 100644 index 0000000000000000000000000000000000000000..878494216a779e3c7cf43d84e0f02c40d7791542 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SCALE_H_ +#define NNACL_KERNEL_SCALE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ScaleStruct { + KernelBase base_; + int axis_; + int data_type_; + int axis_size_; + int outer_size_; + int inner_size_; + bool malloc_scale_; + bool malloc_offset_; + void *scale_; + void *offset_; + void *input_; + void *output_; +} ScaleStruct; + +KernelBase *CreateScale(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SCALE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c new file mode 100644 index 0000000000000000000000000000000000000000..e9637fdf1d7cd493a943dfcd02f85cf32ed18e5a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c @@ -0,0 +1,51 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/shape.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +int ShapeCompute(struct KernelBase *self) { + ShapeStruct *shape = (ShapeStruct *)self; + memcpy(self->out_[OUTPUT_INDEX]->data_, self->in_[FIRST_INPUT]->shape_, shape->shape_size_); + return NNACL_OK; +} + +int ShapeResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + ShapeStruct *shape = (ShapeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(shape); + shape->shape_size_ = self->in_[FIRST_INPUT]->shape_size_ * sizeof(int); + return NNACL_OK; +} + +KernelBase *CreateShape(OpParameter *param, int data_type) { + ShapeStruct *shape = (ShapeStruct *)malloc(sizeof(ShapeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(shape); + shape->base_.Release = DefaultRelease; + shape->base_.Prepare = DefaultPrepare1In1Out; + shape->base_.Resize = ShapeResize; + shape->base_.Compute = ShapeCompute; + return (KernelBase *)shape; +} + +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt32, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeBool, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat16, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat32, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt8, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeUInt8, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt64, CreateShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..3cbc9aa2ce21337732d01851ba11cc1364d1b180 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SHAPE_H_ +#define NNACL_KERNEL_SHAPE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ShapeStruct { + KernelBase base_; + int shape_size_; +} ShapeStruct; + +KernelBase *CreateShape(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SHAPE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c new file mode 100644 index 0000000000000000000000000000000000000000..ae1768c6e181c9fe19bdead1f33835ad438349ee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/size.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int SizeCompute(KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + int *out_data = (int *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + out_data[Index0] = NNACLGetElementNum(in_tensor); + return NNACL_OK; +} + +KernelBase *CreateSize(OpParameter *param, int data_type) { + SizeStruct *size = (SizeStruct *)malloc(sizeof(SizeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(size); + size->base_.Release = DefaultRelease; + size->base_.Prepare = DefaultPrepare1In1Out; + size->base_.Resize = DefaultResize; + size->base_.Compute = SizeCompute; + return (KernelBase *)size; +} + +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeInt32, CreateSize) +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeFloat32, CreateSize) +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeFloat16, CreateSize) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h new file mode 100644 index 0000000000000000000000000000000000000000..32690bc66ae1ed126f286fdc926efdd1bec54afe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SIZE_H_ +#define NNACL_KERNEL_SIZE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SizeStruct { + KernelBase base_; +} SizeStruct; + +KernelBase *CreateSize(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c new file mode 100644 index 0000000000000000000000000000000000000000..28ebacc22f19f3fc1482de0a809582446658fda5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c @@ -0,0 +1,76 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/slice_base.h" +#include "nnacl_c/nnacl_common.h" + +int SliceLaunch(void *cdata, int task_id, float l, float r) { + SliceStruct *slice = (SliceStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(slice); + void *in_data = slice->base_.in_[FIRST_INPUT]->data_; + void *out_data = slice->base_.out_[OUTPUT_INDEX]->data_; + DoSlice(in_data, out_data, slice, task_id, slice->base_.thread_nr_, slice->data_type_size_); + return NNACL_OK; +} + +int SliceResize(KernelBase *self) { + SliceStruct *slice = (SliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(slice); + + InitSliceStruct(slice, self->in_[Index0], self->in_[Index1], self->in_[Index2]); + + if (slice->param_length_ < DIMENSION_8D) { + PadSliceParameterTo8D(slice); + } + return NNACL_OK; +} + +int SliceCompute(KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + + SliceStruct *slice = (SliceStruct *)self; + if (slice->size_[Index5] < self->thread_nr_) { + DoSliceNoParallel(in_tensor->data_, out_tensor->data_, slice, slice->data_type_size_); + return NNACL_OK; + } + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, SliceLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +KernelBase *CreateSlice(OpParameter *param, int data_type) { + SliceStruct *slice = (SliceStruct *)malloc(sizeof(SliceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(slice); + slice->data_type_size_ = DataTypeCSize(data_type); + slice->base_.Release = DefaultRelease; + slice->base_.Prepare = DefaultPrepare3In1Out; + slice->base_.Resize = SliceResize; + slice->base_.Compute = SliceCompute; + return (KernelBase *)slice; +} + +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeInt32, CreateSlice) +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeFloat32, CreateSlice) +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeFloat16, CreateSlice) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h new file mode 100644 index 0000000000000000000000000000000000000000..d4bd3ce4f0582820ce1e75b2a3ff4885a6e40bb4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_SLICE_H_ +#define NNACL_KERNEL_SLICE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SliceStruct { + KernelBase base_; + int data_type_size_; + int32_t begin_[DIMENSION_8D]; + int32_t size_[DIMENSION_8D]; + int32_t shape_[DIMENSION_8D]; + int32_t end_[DIMENSION_8D]; + int32_t param_length_; +} SliceStruct; + +KernelBase *CreateSlice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SLICE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c new file mode 100644 index 0000000000000000000000000000000000000000..967b142d8824f609d73f1b139376cafbe6b0480b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c @@ -0,0 +1,157 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/softmax.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/softmax_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/softmax_fp16.h" +#endif + +int SoftmaxLastAxisRun(void *cdata, int task_id, float l, float r) { + SoftmaxStruct *softmax = (SoftmaxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + + NNACL_CHECK_ZERO_RETURN_ERR(softmax->base_.thread_nr_); + int unit = UP_DIV(softmax->out_plane_size_, softmax->base_.thread_nr_); + + int *in_shape = softmax->base_.in_[FIRST_INPUT]->shape_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, unit, NNACL_ERR); + int begin = task_id * unit; + int end = MSMIN(begin + unit, softmax->out_plane_size_); + int channel = in_shape[softmax->axis_]; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin, channel, NNACL_ERR); + int offset = begin * channel; + + void *input_ptr = softmax->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = softmax->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + +#ifdef ENABLE_FP16 + if (softmax->data_type_ == kNumberTypeFloat16) { + SoftmaxLastAxisFp16((float16_t *)input_ptr + offset, (float16_t *)output_ptr + offset, end - begin, channel); + return NNACL_OK; + } +#endif + return SoftmaxLastAxis((float *)input_ptr + offset, (float *)output_ptr + offset, end - begin, channel); +} + +int SoftmaxRelease(struct KernelBase *self) { + SoftmaxStruct *softmax = (SoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + if (softmax->sum_data_ != NULL) { + self->env_->Free(self->env_->allocator_, softmax->sum_data_); + } + softmax->sum_data_ = NULL; + return NNACL_OK; +} + +int InitSoftmaxParam(SoftmaxStruct *softmax) { + TensorC *in_tensor = softmax->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + int *in_shape = in_tensor->shape_; + + softmax->n_dim_ = (int)in_tensor->shape_size_; + int origin_axis = ((SoftmaxParameter *)softmax->base_.param_)->axis_; + softmax->axis_ = 
origin_axis == -1 ? origin_axis + softmax->n_dim_ : origin_axis;
+
+  NNACL_CHECK_TRUE_RET(softmax->axis_ >= 0, NNACL_SOFTMAX_AXIS_INVALID);
+  NNACL_CHECK_TRUE_RET(softmax->axis_ < (int)in_tensor->shape_size_, NNACL_SOFTMAX_AXIS_INVALID);
+
+  int out_plane_size = 1;
+  for (int i = 0; i < softmax->axis_; ++i) {
+    out_plane_size *= in_shape[i];
+  }
+  int in_plane_size = 1;
+  for (int i = softmax->axis_ + 1; i < softmax->n_dim_; i++) {
+    in_plane_size *= in_shape[i];
+  }
+
+  ExecEnv *env = softmax->base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+
+  softmax->in_plane_size_ = in_plane_size;
+  softmax->out_plane_size_ = out_plane_size;
+
+  (void)softmax->base_.Release(&softmax->base_);
+  if (softmax->in_plane_size_ > 1) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(out_plane_size, in_plane_size, NNACL_ERR);
+    int sum_data_size = out_plane_size * in_plane_size;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(sum_data_size, (int)DataTypeCSize(softmax->data_type_), NNACL_ERR);
+    softmax->sum_data_ = env->Alloc(env->allocator_, sum_data_size * DataTypeCSize(softmax->data_type_));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(softmax->sum_data_);
+  }
+  return NNACL_OK;
+}
+
+int SoftmaxResize(struct KernelBase *self) {
+  SoftmaxStruct *softmax = (SoftmaxStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(softmax);
+  int ret = InitSoftmaxParam(softmax);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  TensorC *in_tensor = self->in_[FIRST_INPUT];
+  int *in_shape = in_tensor->shape_;
+
+  self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_Softmax), in_shape[softmax->axis_], in_shape[softmax->axis_],
+                                        NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_);
+  return NNACL_OK;
+}
+
+int SoftmaxCompute(struct KernelBase *self) {
+  SoftmaxStruct *softmax = (SoftmaxStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(softmax);
+
+  if (softmax->in_plane_size_ == 1) {
+    return self->env_->ParallelLaunch(self->env_->thread_pool_, SoftmaxLastAxisRun, softmax, self->thread_nr_);
+  }
+
+  void *input_ptr = self->in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  void *output_ptr = self->out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+  NNACL_CHECK_NULL_RETURN_ERR(softmax->sum_data_);
+#ifdef ENABLE_FP16
+  if (softmax->data_type_ == kNumberTypeFloat16) {
+    SoftmaxFp16((float16_t *)input_ptr, (float16_t *)output_ptr, (float16_t *)softmax->sum_data_, softmax->axis_,
+                softmax->n_dim_, self->in_[FIRST_INPUT]->shape_);
+    return NNACL_OK;
+  }
+#endif
+  Softmax((float *)input_ptr, (float *)output_ptr, (float *)softmax->sum_data_, softmax->axis_, softmax->n_dim_,
+          self->in_[FIRST_INPUT]->shape_);
+  return NNACL_OK;
+}
+
+KernelBase *CreateSoftmax(OpParameter *param, int data_type) {
+  SoftmaxStruct *softmax = (SoftmaxStruct *)malloc(sizeof(SoftmaxStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(softmax);
+  memset(softmax, 0, sizeof(SoftmaxStruct));
+
+  softmax->sum_data_ = NULL;
+  softmax->data_type_ = data_type;
+  softmax->base_.Release = SoftmaxRelease;
+  softmax->base_.Prepare = DefaultPrepare1In1Out;
+  softmax->base_.Resize = SoftmaxResize;
+  softmax->base_.Compute = SoftmaxCompute;
+  return (KernelBase *)softmax;
+}
+
+REG_KERNEL_CREATOR(PrimType_Softmax, kNumberTypeFloat16, CreateSoftmax)
+REG_KERNEL_CREATOR(PrimType_Softmax, kNumberTypeFloat32, CreateSoftmax)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h
new file mode 100644
index 0000000000000000000000000000000000000000..f37d0e130f7820acf321d4e7620cd9260cea1f22
--- /dev/null
+++
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SOFTMAX_H_ +#define NNACL_KERNEL_SOFTMAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SoftmaxStruct { + KernelBase base_; + int axis_; + int n_dim_; + int in_plane_size_; + int out_plane_size_; + void *sum_data_; + TypeIdC data_type_; + int unit_; +} SoftmaxStruct; + +int InitSoftmaxParam(SoftmaxStruct *softmax); +int SoftmaxRelease(struct KernelBase *self); +KernelBase *CreateSoftmax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c new file mode 100644 index 0000000000000000000000000000000000000000..be845ae1cab9d4c189f904abd2fe24b7bf1bd85e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c @@ -0,0 +1,79 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/splice.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/splice_parameter.h" +#include "nnacl_c/fp32/splice_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/splice_fp16.h" +#endif + +int SpliceCompute(struct KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + NNACL_CHECK_FALSE(input->shape_size_ != output->shape_size_, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(output->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID); + + SpliceParameter *param = (SpliceParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + int src_row = input->shape_[Index1]; + int src_col = input->shape_[Index2]; + int dst_row = output->shape_[Index1]; + int dst_col = output->shape_[Index2]; + + NNACL_CHECK_FALSE(src_col * param->context_dim_ != dst_col, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(param->context_dim_ * dst_row != param->forward_indexes_dim_, NNACL_SPLICE_SHAPE_INVALID); + + for (int i = 0; i < param->forward_indexes_dim_; ++i) { + if (param->forward_indexes_[i] >= src_row) { + return NNACL_SPLICE_SHAPE_INVALID; + } + } + + void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + +#ifdef ENABLE_FP16 + if (input->data_type_ == kNumberTypeFloat16) { + SpliceFp16((float16_t *)input_data, src_row, src_col, param, (float16_t *)output_data, dst_row, dst_col); + return NNACL_OK; + } +#endif + + SpliceFp32((float *)input_data, src_row, src_col, param, (float *)output_data, dst_row, dst_col); + return NNACL_OK; +} + +KernelBase *CreateSplice(OpParameter *param, int data_type) { + SpliceStruct *splice = (SpliceStruct *)malloc(sizeof(SpliceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(splice); + splice->base_.Release = DefaultRelease; + splice->base_.Prepare = DefaultPrepare1In1Out; + splice->base_.Resize = DefaultResize; + splice->base_.Compute = SpliceCompute; + return (KernelBase *)splice; +} + +REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat32, CreateSplice) +REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat16, CreateSplice) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h new file mode 100644 index 0000000000000000000000000000000000000000..45b9e39f1acee387c0db74bbcaae22e7fb015043 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_SPLICE_H_ +#define NNACL_KERNEL_SPLICE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SpliceStruct { + KernelBase base_; +} SpliceStruct; + +KernelBase *CreateSplice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SPLICE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c new file mode 100644 index 0000000000000000000000000000000000000000..5745492281895affb87a4a754ca8b3fbeb46d5f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c @@ -0,0 +1,138 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/stack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/stack_parameter.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/base/stack_base.h" +#include "nnacl_c/tensor_c_utils.h" + +static inline int GetCopyNum(const int *in_shape, int axis, int n_dim) { + int copy_num = 1; + if (axis > 0) { + for (int j = n_dim - 1; j > axis - 1; j--) { + copy_num *= in_shape[j]; + } + } else { + for (int i = 0; i < n_dim; ++i) { + copy_num *= in_shape[i]; + } + } + return copy_num; +} + +static inline int GetOuterSize(const int *in_shape, int axis) { + int outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= in_shape[i]; + } + return outer_size; +} + +int StackRelease(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + if (stack->buffers_ != NULL) { + self->env_->Free(self->env_->allocator_, stack->buffers_); + stack->buffers_ = NULL; + } + return NNACL_OK; +} + +int StackPrepare(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + stack->buffers_ = + (void **)self->env_->Alloc(self->env_->allocator_, (self->in_size_ + self->out_size_) * sizeof(void *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack->buffers_); + return NNACL_OK; +} + +int StackResize(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + int origin_axis = ((StackParameter *)self->param_)->axis_; + stack->axis_ = origin_axis < 0 ? 
origin_axis + (int)input->shape_size_ + 1 : origin_axis; + + if (self->in_size_ == 1) { + NNACL_CHECK_FALSE(NNACLGetElementNum(input) <= 0, NNACL_STACK_TENSOR_SHAPE_INVALID); + stack->copy_size_ = (size_t)NNACLGetElementNum(input) * DataTypeCSize(stack->data_type_); + stack->outer_size_ = 1; + } else { + NNACL_CHECK_FALSE((int)input->shape_size_ < stack->axis_, NNACL_STACK_TENSOR_SHAPE_INVALID); + size_t copy_num = (size_t)GetCopyNum(input->shape_, stack->axis_, input->shape_size_); + stack->copy_size_ = copy_num * DataTypeCSize(stack->data_type_); + stack->outer_size_ = GetOuterSize(input->shape_, stack->axis_); + } + + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_Stack), stack->copy_size_, stack->copy_size_, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + self->thread_nr_ = NNACL_MIN(UP_DIV(stack->outer_size_, NNACL_STACK_STEP), self->thread_nr_); + return NNACL_OK; +} + +int StackRun(void *cdata, int task_id, float l, float r) { + StackStruct *stack = (StackStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(stack); + + NNACL_CHECK_TRUE_RET(stack->base_.thread_nr_ != 0, NNACL_ERR); + int step = UP_DIV(stack->outer_size_, stack->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, step, NNACL_ERR); + int start = task_id * step; + int end = NNACL_MIN(start + step, stack->outer_size_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stack->base_.in_size_ * (size_t)start, stack->copy_size_, NNACL_ERR); + + void *output_data = (void *)(stack->base_.out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(output_data); + uint8_t *output = (uint8_t *)output_data + stack->base_.in_size_ * (size_t)start * stack->copy_size_; + + Stack(stack->buffers_, (void *)output, stack->base_.in_size_, stack->copy_size_, start, end); + return NNACL_OK; +} + +int StackCompute(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + + for (size_t i = 0; i < self->in_size_; ++i) { + stack->buffers_[i] = self->in_[i]->data_; + NNACL_CHECK_NULL_RETURN_ERR(stack->buffers_[i]); + } + stack->buffers_[self->in_size_] = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(stack->buffers_[self->in_size_]); + return self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_); +} + +KernelBase *CreateStack(OpParameter *param, int data_type) { + StackStruct *stack = (StackStruct *)malloc(sizeof(StackStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack); + stack->buffers_ = NULL; + stack->data_type_ = data_type; + stack->base_.Release = StackRelease; + stack->base_.Prepare = StackPrepare; + stack->base_.Resize = StackResize; + stack->base_.Compute = StackCompute; + return (KernelBase *)stack; +} + +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat32, CreateStack) +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeInt32, CreateStack) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..e02d1cef1d1d9d903845f2023a5ea888c3616970 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_STACK_H_ +#define NNACL_KERNEL_STACK_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +#define NNACL_STACK_STEP 64 + +typedef struct StackStruct { + KernelBase base_; + TypeIdC data_type_; + int axis_; + int outer_size_; + size_t copy_size_; + void **buffers_; +} StackStruct; + +KernelBase *CreateStack(OpParameter *param, int data_type); +int StackRun(void *cdata, int task_id, float l, float r); +int StackRelease(KernelBase *self); +int StackPrepare(KernelBase *self); +int StackResize(KernelBase *self); + +#endif // NNACL_KERNEL_STACK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c new file mode 100644 index 0000000000000000000000000000000000000000..3db06715311e8abfb7fa4d364a1ac01d8b976038 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c @@ -0,0 +1,278 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
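Like every kernel in this patch, StackStruct embeds KernelBase as its first member, so the registered creator can hand back a KernelBase * that is later cast to the concrete struct. The expected call order is Prepare, then Resize (repeated on shape change), then Compute, then Release; a hypothetical driver, with all tensor and env wiring assumed to be done by the LiteRT runtime:

/* Sketch only: the runtime normally owns this sequence. */
int RunKernelOnce(KernelBase *kernel) {
  int ret = kernel->Prepare(kernel); /* one-time allocations, e.g. stack buffers_ */
  if (ret != NNACL_OK) {
    return ret;
  }
  ret = kernel->Resize(kernel); /* shape-dependent setup, e.g. axis_ and copy_size_ */
  if (ret != NNACL_OK) {
    return ret;
  }
  ret = kernel->Compute(kernel); /* parallel execution via ParallelLaunch */
  (void)kernel->Release(kernel); /* frees Prepare-time resources */
  return ret;
}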
+ */
+
+#include "nnacl_c/kernel/strided_slice.h"
+#include "nnacl_c/strided_slice_parameter.h"
+#include "nnacl_c/nnacl_common.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/strided_slice_fp32.h"
+#include "nnacl_c/kernel/reshape.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+#define MinStridedSlicePerThread 16384
+
+int StridedSliceFastRunImpl(void *cdata, int task_id, float l, float r) {
+  StridedSliceStruct *strided_slice = (StridedSliceStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(strided_slice);
+
+  uint8_t *input_data = strided_slice->base_.in_[FIRST_INPUT]->data_;
+  uint8_t *output_data = strided_slice->base_.out_[OUTPUT_INDEX]->data_;
+  int *in_shape = strided_slice->base_.in_[FIRST_INPUT]->shape_;
+  int *out_shape = strided_slice->base_.out_[OUTPUT_INDEX]->shape_;
+  int begin_index = strided_slice->begins_[strided_slice->split_axis_];
+  int calc_num = task_id * strided_slice->cal_num_per_thread_;
+  int64_t inner_size = (int64_t)strided_slice->inner_size_;
+
+  if (strided_slice->parallel_on_outer_) {
+    uint8_t *cur_in_ptr = input_data + (calc_num * in_shape[strided_slice->split_axis_] + begin_index) * inner_size;
+    uint8_t *cur_out_ptr = output_data + calc_num * out_shape[strided_slice->split_axis_] * inner_size;
+    int cur_outer = (int)strided_slice->outer_ - calc_num;
+    if (cur_outer <= 0) {
+      return NNACL_OK;
+    }
+    if (cur_outer > strided_slice->cal_num_per_thread_) {
+      cur_outer = strided_slice->cal_num_per_thread_;
+    }
+    FastStride(cur_in_ptr, cur_out_ptr, out_shape[strided_slice->split_axis_],
+               strided_slice->strides_[strided_slice->split_axis_], cur_outer, strided_slice->inner_size_,
+               (size_t)in_shape[strided_slice->split_axis_] * strided_slice->inner_size_);
+    return NNACL_OK;
+  }
+
+  if (strided_slice->parallel_on_split_axis_) {
+    uint8_t *cur_in_ptr =
+      input_data + (calc_num * strided_slice->strides_[strided_slice->split_axis_] + begin_index) * inner_size;
+    uint8_t *cur_out_ptr = output_data + calc_num * inner_size;
+    int cal_axis_num = out_shape[strided_slice->split_axis_] - calc_num;
+    if (cal_axis_num <= 0) {
+      return NNACL_OK;
+    }
+    if (cal_axis_num > strided_slice->cal_num_per_thread_) {
+      cal_axis_num = strided_slice->cal_num_per_thread_;
+    }
+    FastStride(cur_in_ptr, cur_out_ptr, (uint32_t)cal_axis_num, strided_slice->strides_[strided_slice->split_axis_], 1,
+               strided_slice->inner_size_, 0);
+    return NNACL_OK;
+  }
+
+  return NNACL_STRIDED_SLICE_INVALID_PARALLEL_MOD;
+}
+
+int StridedSliceFastRun(StridedSliceStruct *strided_slice) {
+  // Recompute the inner size in bytes: the tensor's data type may have been
+  // changed from float32 to float16 during the fp16 sub-graph partition process.
+  size_t data_type_size = DataTypeCSize(strided_slice->base_.in_[FIRST_INPUT]->data_type_);
+  NNACL_CHECK_FALSE(data_type_size == 0, NNACL_STRIDED_SLICE_UNSUPPORTED_DATA_TYPE);
+  strided_slice->inner_size_ = strided_slice->inner_ * data_type_size;
+
+  NNACL_CHECK_NULL_RETURN_ERR(strided_slice->base_.in_[FIRST_INPUT]->data_);
+  NNACL_CHECK_NULL_RETURN_ERR(strided_slice->base_.out_[OUTPUT_INDEX]->data_);
+  return strided_slice->base_.env_->ParallelLaunch(strided_slice->base_.env_->thread_pool_, StridedSliceFastRunImpl,
+                                                   strided_slice, strided_slice->base_.thread_nr_);
+}
+
+bool StridedSliceMatchInOutShapeEqualPattern(StridedSliceStruct *strided_slice) {
+  for (int i = 0; i < MAX_SHAPE_SIZE; i++) {
+    if (strided_slice->strides_[i] < 0) {
+      return false;
+    }
+  }
+
+  TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX];
+
+  if (in_tensor->data_type_ != out_tensor->data_type_) {
+    return false;
+  }
+
+  if (in_tensor->shape_size_ != out_tensor->shape_size_) {
+    return false;
+  }
+
+  if (in_tensor->shape_size_ < ONE_TENSOR) {
+    return false;
+  }
+
+  for (size_t i = 0; i < in_tensor->shape_size_; ++i) {
+    if (in_tensor->shape_[i] != out_tensor->shape_[i]) {
+      return false;
+    }
+    if (in_tensor->shape_[i] == -1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+int StridedSliceSoftCopyInputToOutput(StridedSliceStruct *strided_slice) {
+  TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor->data_);
+  TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_);
+
+  int total_num = NNACLGetElementNum(in_tensor);
+  NNACL_CHECK_FALSE(total_num == 0, NNACL_STRIDED_SLICE_INVALID_DATA_SIZE);
+
+  strided_slice->base_.thread_nr_ =
+    NNACL_MIN(strided_slice->base_.thread_nr_, UP_DIV(total_num, MinStridedSlicePerThread));
+  if (strided_slice->base_.thread_nr_ < 1) {
+    strided_slice->base_.thread_nr_ = 1;
+  }
+
+  int block_num = UP_DIV(total_num, strided_slice->base_.thread_nr_);
+  strided_slice->base_.thread_nr_ = UP_DIV(total_num, block_num);
+
+  if (in_tensor->data_ != out_tensor->data_) {
+    if (strided_slice->base_.thread_nr_ == 1) {
+      (void)memcpy(out_tensor->data_, in_tensor->data_, total_num * (int)DataTypeCSize(in_tensor->data_type_));
+      return NNACL_OK;
+    }
+    ReshapeStruct reshape;
+    reshape.base_.in_ = strided_slice->base_.in_;
+    reshape.base_.out_ = strided_slice->base_.out_;
+    reshape.block_num_ = block_num;
+    reshape.total_num_ = total_num;
+    reshape.base_.thread_nr_ = strided_slice->base_.thread_nr_;
+    return strided_slice->base_.env_->ParallelLaunch(strided_slice->base_.env_->thread_pool_, ParallelReshape,
+                                                     &reshape, strided_slice->base_.thread_nr_);
+  }
+  return NNACL_OK;
+}
+
+bool StridedSliceMatchFastPattern(StridedSliceStruct *strided_slice) {
+  // This function checks whether exactly one dimension differs between the
+  // input and output shapes. If so, the slice reduces to a strided block copy.
+  // Example 1:
+  //   input shape info:  [1, 80, 46, 40]
+  //   output shape info: [1, 80, 20, 40]
+  // Example 2:
+  //   input shape info:  [1, 46, 40]
+  //   output shape info: [1, 20, 40]
+  TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX];
+  if (in_tensor->shape_size_ != out_tensor->shape_size_) {
+    return false;
+  }
+
+  int axis_list[MAX_SHAPE_SIZE];
+  int axis_list_size = 0;
+  for (size_t i = 0; i < in_tensor->shape_size_; i++) {
+    if (in_tensor->shape_[i] != out_tensor->shape_[i]) {
+      axis_list[axis_list_size++] = (int)i;
+    }
+  }
+  if (axis_list_size == 1) {
+    strided_slice->split_axis_ = axis_list[Index0];
+    return true;
+  }
+  return false;
+}
+
+void StridedSliceInitFastRunParam(StridedSliceStruct *strided_slice) {
+  TensorC *input_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  int *in_shape = input_tensor->shape_;
+  int *out_shape = strided_slice->base_.out_[OUTPUT_INDEX]->shape_;
+
+  // reset and compute the outer/inner sizes around the split axis
+  strided_slice->outer_ = 1;
+  strided_slice->inner_ = 1;
+  for (int i = 0; i < strided_slice->split_axis_; ++i) {
+    strided_slice->outer_ *= (size_t)in_shape[i];
+  }
+  for (size_t i = (size_t)strided_slice->split_axis_ + 1; i < input_tensor->shape_size_; i++) {
+    strided_slice->inner_ *= (size_t)in_shape[i];
+  }
+
+  if (strided_slice->outer_ == 1) {
+    strided_slice->parallel_on_split_axis_ = true;
+    strided_slice->parallel_on_outer_ = false;
+  } else {
+    strided_slice->parallel_on_split_axis_ = false;
+    strided_slice->parallel_on_outer_ = true;
+  }
+
+  strided_slice->base_.thread_nr_ = strided_slice->base_.UpdateThread(
+    TC_TYPE(PrimType_StridedSlice, strided_slice->parallel_on_outer_), 1, 1,
+    NNACLGetElementNum(strided_slice->base_.out_[OUTPUT_INDEX]), strided_slice->base_.thread_nr_);
+
+  strided_slice->cal_num_per_thread_ =
+    strided_slice->parallel_on_split_axis_
+      ?
UP_DIV(out_shape[strided_slice->split_axis_], strided_slice->base_.thread_nr_) + : UP_DIV((int)strided_slice->outer_, strided_slice->base_.thread_nr_); +} + +int StridedSliceResize(KernelBase *self) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->shape_size_ > MAX_SHAPE_SIZE, NNACL_STRIDED_SLICE_INVALID_SHAPE_SIZE); + + StridedSliceParameter *param = (StridedSliceParameter *)self->param_; + memcpy(strided_slice->begins_, param->begins_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->ends_, param->ends_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->in_shape_, param->in_shape_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->strides_, param->strides_, MAX_SHAPE_SIZE * sizeof(int)); + strided_slice->in_shape_size_ = param->in_shape_length_; + + strided_slice->soft_copy_mode_ = StridedSliceMatchInOutShapeEqualPattern(strided_slice); + strided_slice->fast_run_ = StridedSliceMatchFastPattern(strided_slice); + if (strided_slice->fast_run_) { + StridedSliceInitFastRunParam(strided_slice); + } + + if (strided_slice->soft_copy_mode_ == false && strided_slice->fast_run_ == false) { + return PadStridedSliceParameterTo8D(strided_slice); + } + + return NNACL_OK; +} + +int StridedSliceCompute(KernelBase *self) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + if (strided_slice->soft_copy_mode_) { + return StridedSliceSoftCopyInputToOutput(strided_slice); + } + if (strided_slice->fast_run_) { + return StridedSliceFastRun(strided_slice); + } + + return DoStridedSliceIn8D(self->in_[FIRST_INPUT]->data_, self->out_[OUTPUT_INDEX]->data_, strided_slice); +} + +KernelBase *CreateStridedSlice(OpParameter *param, int data_type) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)malloc(sizeof(StridedSliceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(strided_slice); + strided_slice->data_type_ = data_type; + strided_slice->base_.Release = DefaultRelease; + strided_slice->base_.Prepare = DefaultPrepare1In1Out; + strided_slice->base_.Resize = StridedSliceResize; + strided_slice->base_.Compute = StridedSliceCompute; + return (KernelBase *)strided_slice; +} + +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat32, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..5fe3fc80c04a71883bb7c834551f446995d9ee7c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h @@ -0,0 +1,47 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
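When StridedSliceMatchFastPattern holds, the whole op becomes one FastStride call per task: a strided block copy over the split axis, with either the outer loop or the split axis itself divided among threads. A self-contained sketch of that primitive under the parameter meanings used at the call sites above (illustrative, not the tuned nnacl implementation):

#include <stdint.h>
#include <string.h>

/* Copy `unit_num` blocks of `unit_size` bytes, taking every `stride`-th block
 * from the source; repeat for `outer` slices spaced `in_unit_size` bytes apart. */
static void FastStrideSketch(const uint8_t *in, uint8_t *out, int unit_num, int stride, int outer, size_t unit_size,
                             size_t in_unit_size) {
  for (int i = 0; i < outer; ++i) {
    const uint8_t *src = in + (size_t)i * in_unit_size;
    for (int j = 0; j < unit_num; ++j) {
      memcpy(out, src + (size_t)j * stride * unit_size, unit_size);
      out += unit_size;
    }
  }
}

With stride == 1 the inner loop degenerates to a single contiguous copy per outer slice, which a tuned implementation would presumably special-case.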
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_STRIDED_SLICE_H_ +#define NNACL_KERNEL_STRIDED_SLICE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct StridedSliceStruct { + KernelBase base_; + TypeIdC data_type_; + bool fast_run_; + bool soft_copy_mode_; + bool parallel_on_outer_; + bool parallel_on_split_axis_; + + int split_axis_; + int in_shape_size_; + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE]; + + size_t inner_; + size_t outer_; + size_t inner_size_; + int cal_num_per_thread_; +} StridedSliceStruct; + +KernelBase *CreateStridedSlice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_STRIDED_SLICE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c new file mode 100644 index 0000000000000000000000000000000000000000..18e339503b89ddff620151a2db03400b942c1123 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c @@ -0,0 +1,182 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
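The third mode, soft_copy_mode_ (flagged in StridedSliceStruct above), applies when input and output shapes are identical and all strides are non-negative: the slice is then an element-for-element copy, split into equal blocks per thread through the shared ParallelReshape worker. A rough sketch of one task under that assumption (block partitioning as in StridedSliceSoftCopyInputToOutput; names hypothetical):

#include <stdint.h>
#include <string.h>

/* Copy block `task_id` of `block_num` elements; the last block may be short. */
static void SoftCopyTaskSketch(const uint8_t *in, uint8_t *out, int total_num, int block_num, size_t elem_size,
                               int task_id) {
  int start = task_id * block_num;
  int num = (start + block_num <= total_num) ? block_num : (total_num - start);
  if (num > 0) {
    memcpy(out + (size_t)start * elem_size, in + (size_t)start * elem_size, (size_t)num * elem_size);
  }
}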
+ */
+
+#include "nnacl_c/kernel/tile.h"
+#include "nnacl_c/tile_parameter.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/nnacl_common.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/tile_base.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+
+#define kDoubleInputsSize 2
+
+int TileDoubleInputScenes(TileStruct *tile) {
+  TensorC *t = tile->base_.in_[SECOND_INPUT];
+  if (t->data_ == NULL) {
+    tile->resize_done_ = false;
+    return NNACL_OK;
+  }
+
+  NNACL_CHECK_FALSE(NNACLGetElementNum(t) > (int)tile->base_.in_[FIRST_INPUT]->shape_size_,
+                    NNACL_TILE_SECOND_INPUT_NUM_INVALID);
+  NNACL_CHECK_FALSE(t->data_type_ != kNumberTypeInt && t->data_type_ != kNumberTypeInt32,
+                    NNACL_TILE_SECOND_INPUT_DATA_TYPE_INVALID);
+
+  int *input1_addr = (int *)(t->data_);
+  for (int i = 0; i < NNACLGetElementNum(t); ++i) {
+    NNACL_CHECK_FALSE(input1_addr[i] <= 0, NNACL_TILE_SECOND_INPUT_VALUE_INVALID);
+    tile->dims_[i] = i;
+    tile->multiples_[i] = input1_addr[i];
+  }
+  return NNACL_OK;
+}
+
+int SimpleTileImpl(TileStruct *tile, int task_id) {
+  NNACL_CHECK_ZERO_RETURN_ERR(tile->base_.thread_nr_);
+  size_t unit = UP_DIV(tile->fast_outer_size_, (size_t)tile->base_.thread_nr_);
+  if (unit == 0 && task_id > 0) {
+    return NNACL_OK;
+  }
+  NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(unit, (size_t)task_id), NNACL_ERR);
+  size_t begin = unit * (size_t)(task_id);
+  size_t end = MSMIN(begin + unit, tile->fast_outer_size_);
+  TileSimple(tile->input_addr_, tile->output_addr_, begin, end, tile);
+  return NNACL_OK;
+}
+
+int SimpleTile(void *cdata, int task_id, float l, float r) {
+  TileStruct *tile = (TileStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(tile);
+  return SimpleTileImpl(tile, task_id);
+}
+
+int TileFillOneDimTileParam(TileStruct *tile) {
+  // check whether exactly one dimension is tiled (multiple greater than 1)
+  int large_one_multiple_count = 0;
+  int multiple = 0;
+  int mul_index = 0;
+
+  for (int i = 0; i < tile->in_dim_; ++i) {
+    if (tile->multiples_[i] > 1) {
+      large_one_multiple_count++;
+      multiple = tile->multiples_[i];
+      mul_index = i;
+    }
+  }
+  tile->one_dim_tile_ = large_one_multiple_count == 1;
+  if (tile->one_dim_tile_) {
+    tile->fast_multiple_ = (size_t)multiple;
+    NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->in_shape_[mul_index], tile->in_strides_[mul_index]), NNACL_ERR);
+    tile->fast_stride_ = (size_t)(tile->in_shape_[mul_index] * tile->in_strides_[mul_index]);
+    NNACL_CHECK_FALSE(tile->fast_stride_ < 1, NNACL_TILE_INPUT_SHAPE_INVALID);
+    tile->fast_outer_size_ = (size_t)NNACLGetElementNum(tile->base_.in_[FIRST_INPUT]) / tile->fast_stride_;
+  }
+  tile->resize_done_ = true;
+  return NNACL_OK;
+}
+
+int TileResize(struct KernelBase *self) {
+  TileStruct *tile = (TileStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(tile);
+  TileParameter *param = (TileParameter *)(self->param_);
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+
+  tile->dims_size_ = param->dims_size_;
+  for (int i = 0; i < MAX_SHAPE_SIZE; i++) {
+    tile->dims_[i] = param->dims_[i];
+    tile->multiples_[i] = param->multiples_[i];
+  }
+
+  if (self->in_size_ == kDoubleInputsSize) {
+    int ret = TileDoubleInputScenes(tile);
+    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
+  }
+
+  TensorC *input = self->in_[0];
+  TensorC *output = self->out_[0];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+
+  tile->in_dim_ = (int)input->shape_size_;
+  NNACL_CHECK_TRUE_RET(tile->in_dim_ > 0 && tile->in_dim_ <= MAX_SHAPE_SIZE, NNACL_TILE_INPUT_SHAPE_INVALID);
+  NNACL_CHECK_FALSE((int)output->shape_size_ < tile->in_dim_, NNACL_TILE_INPUT_SHAPE_INVALID);
+
+  for (int i = 0; i <
tile->in_dim_; ++i) { + tile->in_shape_[i] = input->shape_[i]; + tile->out_shape_[i] = output->shape_[i]; + } + + ComputeStrides(tile->in_shape_, tile->in_strides_, tile->in_dim_); + ComputeStrides(tile->out_shape_, tile->out_strides_, tile->in_dim_); + + for (size_t i = 0; i < tile->dims_size_; i++) { + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->multiples_[i], tile->in_shape_[i]), NNACL_ERRCODE_MUL_OVERFLOW); + int ele_num = tile->multiples_[i] * tile->in_shape_[i] - 1; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->out_strides_[i], ele_num), NNACL_ERRCODE_MUL_OVERFLOW); + } + + int ret = TileFillOneDimTileParam(tile); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (tile->one_dim_tile_) { + self->thread_nr_ = + self->UpdateThread(TC_TYPE(PrimType_TileFusion, 0), 0, 0, tile->fast_outer_size_, self->thread_nr_); + } + return NNACL_OK; +} + +int TileCompute(struct KernelBase *self) { + TileStruct *tile = (TileStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tile); + tile->input_addr_ = (uint8_t *)(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(tile->input_addr_); + tile->output_addr_ = (uint8_t *)(self->out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(tile->output_addr_); + + if (!tile->resize_done_) { + int ret = TileResize(self); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + NNACL_CHECK_FALSE(tile->resize_done_ == false, NNACL_TILE_RESIZE_IN_RUNTIME_FAILED); + } + + tile->data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_TRUE_RET(tile->data_size_ > 0, NNACL_UNSUPPORTED_DATA_TYPE); + + if (tile->one_dim_tile_) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, SimpleTile, self, self->thread_nr_); + } + + Tile(tile->input_addr_, tile->output_addr_, tile); + return NNACL_OK; +} + +KernelBase *CreateTile(OpParameter *param, int data_type) { + TileStruct *tile = (TileStruct *)malloc(sizeof(TileStruct)); + NNACL_CHECK_NULL_RETURN_NULL(tile); + tile->resize_done_ = false; + tile->base_.Release = DefaultRelease; + tile->base_.Prepare = DefaultPrepare1In1Out; + tile->base_.Resize = TileResize; + tile->base_.Compute = TileCompute; + return (KernelBase *)tile; +} + +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeInt32, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeFloat32, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeFloat16, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeBool, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeUInt8, CreateTile) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h new file mode 100644 index 0000000000000000000000000000000000000000..e71000043bd136e3be3c58f092382b924c9e5779 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h @@ -0,0 +1,48 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
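TileFillOneDimTileParam above enables a fast path when exactly one multiple is greater than 1: the input is then fast_outer_size_ contiguous runs of fast_stride_ elements, and tiling just writes each run fast_multiple_ times in a row. A sketch in element counts instead of bytes (hypothetical helper, not the real TileSimple signature):

#include <stddef.h>
#include <string.h>

/* Repeat each of `outer` contiguous runs of `stride` floats `multiple` times. */
static void OneDimTileSketch(const float *in, float *out, size_t outer, size_t stride, size_t multiple) {
  for (size_t i = 0; i < outer; ++i) {
    for (size_t m = 0; m < multiple; ++m) {
      memcpy(out + (i * multiple + m) * stride, in + i * stride, stride * sizeof(float));
    }
  }
}

For instance, tiling a [2, 3] input with multiples [1, 2] gives fast_stride_ = 3, fast_outer_size_ = 2 and fast_multiple_ = 2: each row of three floats is emitted twice, producing the [2, 6] output.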
+ */ + +#ifndef NNACL_KERNEL_TILE_H_ +#define NNACL_KERNEL_TILE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct TileStruct { + KernelBase base_; + bool one_dim_tile_; + bool resize_done_; + int dims_[MAX_SHAPE_SIZE]; + size_t dims_size_; + uint8_t *input_addr_; + uint8_t *output_addr_; + + int multiples_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE]; + int out_shape_[MAX_SHAPE_SIZE]; + int in_strides_[MAX_SHAPE_SIZE]; + int out_strides_[MAX_SHAPE_SIZE]; + + int in_dim_; + size_t data_size_; + size_t fast_outer_size_; + size_t fast_stride_; + size_t fast_multiple_; +} TileStruct; + +KernelBase *CreateTile(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TILE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c new file mode 100644 index 0000000000000000000000000000000000000000..470ea301dc49ea19e23e3d9c527a9865be120440 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c @@ -0,0 +1,358 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/transpose.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/transpose_fp16.h" +#endif + +/* opt perm: { 0, 2, 1 } */ +#define OPT_PERM_0 0 +#define OPT_PERM_1 2 +#define OPT_PERM_2 1 + +int TransposeComputeinMultiThread(TransposeStruct *transpose, int task_id) { + void *in = transpose->base_.in_[FIRST_INPUT]->data_; + void *out = transpose->base_.out_[OUTPUT_INDEX]->data_; + + if (transpose->opt_run_) { + transpose->nhwc2nchw_(in, out, transpose->opt_perm_[FIRST_INPUT], transpose->opt_perm_[SECOND_INPUT], + transpose->opt_perm_[THIRD_INPUT], task_id, transpose->base_.thread_nr_); + } else { + transpose->optimize_(in, out, transpose->out_shape_, transpose->perm_, transpose->strides_, transpose->out_strides_, + transpose->num_axes_, task_id, transpose->base_.thread_nr_); + } + return NNACL_OK; +} + +int TransposeComputeinSingleThread(TransposeStruct *transpose) { + if (transpose->opt_run_ || transpose->num_axes_ > DIMENSION_6D) { + return TransposeComputeinMultiThread(transpose, 0); + } + + void *in = transpose->base_.in_[FIRST_INPUT]->data_; + void *out = transpose->base_.out_[OUTPUT_INDEX]->data_; + return transpose->compute_(in, out, transpose->out_shape_, transpose->perm_, transpose->strides_, + transpose->out_strides_, transpose->data_num_, transpose->num_axes_); +} + +int ResetTransposeStatus(TransposeStruct *transpose) { + transpose->num_axes_ = 0; + if (transpose->base_.in_size_ == C2NUM) { + transpose->num_axes_ = NNACLGetElementNum(transpose->base_.in_[SECOND_INPUT]); + transpose->perm_size_ = 
transpose->base_.in_[SECOND_INPUT]->shape_[0]; + } + + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + if (in_tensor->shape_size_ > MAX_TRANSPOSE_DIM_SIZE) { + return NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE; + } + + int trans_nd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1}; + int *perm_data; + if ((int)in_tensor->shape_size_ != transpose->num_axes_) { + perm_data = trans_nd; + if (in_tensor->shape_size_ == Num3 && transpose->num_axes_ == Num4) { + transpose->num_axes_ = Num3; + } + if (transpose->num_axes_ == 0) { + for (size_t i = 0; i < in_tensor->shape_size_; ++i) { + trans_nd[i] = (int)in_tensor->shape_size_ - 1 - (int)i; + } + transpose->num_axes_ = (int)in_tensor->shape_size_; + } + } else { + NNACL_CHECK_TRUE_RET(transpose->base_.in_size_ == TWO_TENSOR, NNACL_TRANSPOSE_INPUT_TENSOR_NUM_INVALID); + TensorC *perm_tensor = transpose->base_.in_[SECOND_INPUT]; + if (perm_tensor->data_type_ != kNumberTypeInt32) { + return NNACL_TRANSPOSE_PERM_TENSOR_INVALID; + } + perm_data = (int *)(perm_tensor->data_); + NNACL_CHECK_NULL_RETURN_ERR(perm_data); + int ele_num = NNACLGetElementNum(perm_tensor); + for (int i = 0; i < ele_num; i++) { + for (int j = 0; j < ele_num; j++) { + if (i == perm_data[j]) { + break; + } + if (j == ele_num - 1) { + return NNACL_TRANSPOSE_PERM_TENSOR_VALUE_INVALID; + } + } + } + } + + NNACL_CHECK_TRUE_RET(transpose->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, NNACL_TRANSPOSE_PERM_DIMS_INVALID); + for (int i = 0; i < transpose->num_axes_; ++i) { + transpose->perm_[i] = perm_data[i]; + } + return NNACL_OK; +} + +void TransposeFreeSegments(int **segments, int segments_size) { + for (int i = 0; i < segments_size; i++) { + if (segments[i] != NULL) { + free(segments[i]); + segments[i] = NULL; + } + } +} + +int TransposeOptimizeShape(TransposeStruct *transpose) { + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + int *in_shape = in_tensor->shape_; + + // first step, delete dimension where value is 1. 
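+  // Illustration: shape [4, 1, 8] with perm [2, 1, 0] drops the size-1 middle
+  // axis to become shape [4, 8] with perm [1, 0]; perm_diff below counts how
+  // many deleted axes precede each original perm entry so it can be rebased.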
+  int in_shape_temp[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  int in_shape_temp_size = 0;
+  int perm_diff[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  for (size_t i = 0; i < in_tensor->shape_size_; ++i) {
+    if (in_shape[i] != 1) {
+      in_shape_temp[in_shape_temp_size++] = in_shape[i];
+      continue;
+    }
+    for (size_t j = 0; j < in_tensor->shape_size_; ++j) {
+      if (transpose->perm_[j] < (int)(i)) {
+        continue;
+      }
+      if (transpose->perm_[j] == (int)(i)) {
+        perm_diff[j] = (int)(i) + 1;
+      } else {
+        perm_diff[j] += 1;
+      }
+    }
+  }
+
+  int perm_temp[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  int perm_temp_size = 0;
+  for (size_t i = 0; i < in_tensor->shape_size_; ++i) {
+    int diff = transpose->perm_[i] - perm_diff[i];
+    if (diff < 0) {
+      continue;
+    }
+    perm_temp[perm_temp_size++] = diff;
+  }
+
+  NNACL_CHECK_TRUE_RET(in_shape_temp_size == perm_temp_size, NNACL_TRANSPOSE_PERM_DELETE_DIMENSION_FAILED);
+
+  // second step, fuse contiguous dimensions.
+  int axis_num = in_shape_temp_size;
+  int *segments[MAX_TRANSPOSE_DIM_SIZE];
+  int segment_sizes[MAX_TRANSPOSE_DIM_SIZE];
+  int segments_size = 0;
+  for (int i = 0; i < axis_num;) {
+    int segment[MAX_TRANSPOSE_DIM_SIZE];
+    int segment_size = 0;
+    segment[segment_size++] = perm_temp[i];
+    ++i;
+    for (; i < axis_num; ++i) {
+      if (perm_temp[i] - 1 != perm_temp[i - 1]) {
+        break;
+      }
+      segment[segment_size++] = perm_temp[i];
+    }
+
+    segments[segments_size] = malloc(segment_size * sizeof(int));
+    if (segments[segments_size] == NULL) {
+      TransposeFreeSegments(segments, segments_size);
+      return NNACL_NULL_PTR;
+    }
+    memcpy(segments[segments_size], segment, segment_size * sizeof(int));
+    segment_sizes[segments_size] = segment_size;
+    segments_size++;
+  }
+
+  transpose->in_shape_size_ = segments_size;
+  transpose->perm_size_ = segments_size;
+  for (int i = 0; i < segments_size; i++) {
+    transpose->in_shape_[i] = 1;
+    transpose->perm_[i] = 0;
+  }
+  for (int i = 0; i < segments_size; ++i) {
+    for (int j = 0; j < segments_size; ++j) {
+      transpose->perm_[i] += (segments[j][FIRST_INPUT] < segments[i][FIRST_INPUT] ?
1 : 0); + } + for (int k = 0; k < segment_sizes[i]; ++k) { + transpose->in_shape_[transpose->perm_[i]] *= in_shape_temp[segments[i][k]]; + } + } + TransposeFreeSegments(segments, segments_size); + return NNACL_OK; +} + +void SetTransposeOptInfo(TransposeStruct *transpose) { + // now perm is [1, 0] or [0, 2, 1] + if (transpose->perm_size_ == C2NUM) { + transpose->opt_perm_[FIRST_INPUT] = 1; + transpose->opt_perm_[SECOND_INPUT] = transpose->in_shape_[FIRST_INPUT]; + transpose->opt_perm_[THIRD_INPUT] = transpose->in_shape_[transpose->in_shape_size_ - 1]; + } else { + transpose->opt_perm_[FIRST_INPUT] = transpose->in_shape_[FIRST_INPUT]; + transpose->opt_perm_[SECOND_INPUT] = transpose->in_shape_[SECOND_INPUT]; + transpose->opt_perm_[THIRD_INPUT] = transpose->in_shape_[transpose->in_shape_size_ - 1]; + } +} + +bool TransposeOpt(TransposeStruct *transpose) { + if (transpose->perm_size_ == DIMENSION_2D) { + return true; + } + if (transpose->perm_size_ == DIMENSION_3D && transpose->perm_[FIRST_INPUT] == OPT_PERM_0 && + transpose->perm_[SECOND_INPUT] == OPT_PERM_1 && transpose->perm_[THIRD_INPUT] == OPT_PERM_2) { + return true; + } + return false; +} + +int TransposeComputeOfflineInfo(TransposeStruct *transpose) { + transpose->num_axes_ = transpose->in_shape_size_; + NNACL_CHECK_TRUE_RET(transpose->num_axes_ >= DIMENSION_3D, NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE); + + for (int i = 0; i < transpose->num_axes_; ++i) { + transpose->out_shape_[i] = transpose->in_shape_[transpose->perm_[i]]; + } + transpose->strides_[transpose->num_axes_ - 1] = 1; + transpose->out_strides_[transpose->num_axes_ - 1] = 1; + transpose->data_num_ = NNACLGetElementNum(transpose->base_.in_[FIRST_INPUT]); + for (int i = transpose->num_axes_ - 2; i >= 0; i--) { + transpose->strides_[i] = transpose->in_shape_[i + 1] * transpose->strides_[i + 1]; + transpose->out_strides_[i] = transpose->out_shape_[i + 1] * transpose->out_strides_[i + 1]; + } + return NNACL_OK; +} + +int TransposeCopyInputToOutput(TransposeStruct *transpose) { + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + NNACL_CHECK_NULL_RETURN_ERR(in_tensor->data_); + TensorC *out_tensor = transpose->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + NNACL_CHECK_FALSE(NNACLGetSize(in_tensor) == 0, NNACL_TRANSPOSE_INPUT_TENSOR_VALUD_INVALID); + if (in_tensor->data_ != out_tensor->data_) { + (void)memcpy(out_tensor->data_, in_tensor->data_, NNACLGetSize(in_tensor)); + } + return NNACL_OK; +} + +int TransposeImpl(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + TransposeStruct *transpose = (TransposeStruct *)cdata; + return TransposeComputeinMultiThread(transpose, task_id); +} + +int TransposeCompute(struct KernelBase *self) { + TransposeStruct *transpose = (TransposeStruct *)self; + if (!transpose->is_valid_) { + return TransposeCopyInputToOutput(transpose); + } + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + if (self->thread_nr_ == 1) { + return TransposeComputeinSingleThread(transpose); + } + return self->env_->ParallelLaunch(self->env_->thread_pool_, TransposeImpl, self, self->thread_nr_); +} + +int TransposeResize(struct KernelBase *self) { + TransposeStruct *transpose = (TransposeStruct *)self; + int ret = 
ResetTransposeStatus(transpose); + if (ret != NNACL_OK) { + return ret; + } + transpose->is_valid_ = (int)transpose->base_.in_[FIRST_INPUT]->shape_size_ == transpose->num_axes_ && + (int)transpose->base_.in_[FIRST_INPUT]->shape_size_ == transpose->perm_size_; + if (!transpose->is_valid_) { + return NNACL_OK; + } + + ret = TransposeOptimizeShape(transpose); + if (ret != NNACL_OK) { + return ret; + } + + transpose->is_valid_ = transpose->perm_size_ > DIMENSION_1D; + if (!transpose->is_valid_) { + return NNACL_OK; + } + + transpose->opt_run_ = TransposeOpt(transpose); + if (transpose->opt_run_) { + SetTransposeOptInfo(transpose); + return NNACL_OK; + } + + ret = TransposeComputeOfflineInfo(transpose); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = (!transpose->opt_run_ && transpose->num_axes_ <= DIMENSION_6D) ? 1 : self->thread_nr_; + return NNACL_OK; +} + +int TransposePrepare(struct KernelBase *self) { + int ret = DefaultPrepare1In1Out(self); + if (ret != NNACL_OK) { + return ret; + } + TransposeStruct *transpose = (TransposeStruct *)self; + TransposeParameter *param = (TransposeParameter *)transpose->base_.param_; + if (param->perm_size_ > INT32_MAX) { + return NNACL_TRANSPOSE_PERM_DIMS_INVALID; + } + transpose->perm_size_ = (int)param->perm_size_; + for (int i = 0; i < transpose->perm_size_; i++) { + transpose->perm_[i] = param->perm_[i]; + } + return NNACL_OK; +} + +KernelBase *CreateTranspose(OpParameter *param, int data_type) { + TransposeStruct *transpose = (TransposeStruct *)malloc(sizeof(TransposeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(transpose); + transpose->nhwc2nchw_ = PackNHWCToNCHWFp32; + transpose->optimize_ = TransposeDimsFp32; + transpose->compute_ = DoTransposeFp32; + transpose->base_.Release = DefaultRelease; + transpose->base_.Prepare = TransposePrepare; + transpose->base_.Resize = TransposeResize; + transpose->base_.Compute = TransposeCompute; + if (data_type == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + transpose->nhwc2nchw_ = PackNHWCToNCHWFp16; + transpose->optimize_ = TransposeDimsFp16; + transpose->compute_ = DoTransposeFp16; +#else + free(transpose); + return NULL; +#endif + } + return (KernelBase *)transpose; +} + +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeFloat32, CreateTranspose) +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeFloat16, CreateTranspose) +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeInt32, CreateTranspose) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..ca8e7a0dd99eb7533f2792d7558904f5ed88b9de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
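The opt_run_ fast path in transpose.c covers perms [1, 0] and [0, 2, 1], both of which are a batched 2-D transpose once opt_perm_ is folded to {batch, rows, cols}; it is served by the NHWC-to-NCHW pack routine. A naive reference of what that routine computes (the real PackNHWCToNCHWFp32 is tiled, vectorized, and splits the work by task_id):

#include <stddef.h>

/* Batched 2-D transpose: out[b][c][hw] = in[b][hw][c]. */
static void Nhwc2NchwSketch(const float *in, float *out, int batch, int hw, int c) {
  for (int b = 0; b < batch; ++b) {
    const float *src = in + (size_t)b * hw * c;
    float *dst = out + (size_t)b * hw * c;
    for (int i = 0; i < hw; ++i) {
      for (int j = 0; j < c; ++j) {
        dst[(size_t)j * hw + i] = src[(size_t)i * c + j];
      }
    }
  }
}

For perm [1, 0], SetTransposeOptInfo sets batch = 1, so the sketch collapses to a plain matrix transpose.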
+ */ + +#ifndef NNACL_KERNEL_TRANSPOSE_H_ +#define NNACL_KERNEL_TRANSPOSE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/transpose_parameter.h" + +typedef struct TransposeStruct { + KernelBase base_; + bool is_valid_; + int num_axes_; + int data_num_; + int perm_[MAX_TRANSPOSE_DIM_SIZE]; + int perm_size_; + int in_shape_[MAX_TRANSPOSE_DIM_SIZE]; /* after shape optimize */ + int in_shape_size_; + int out_shape_[MAX_TRANSPOSE_DIM_SIZE]; + int strides_[MAX_TRANSPOSE_DIM_SIZE]; + int out_strides_[MAX_TRANSPOSE_DIM_SIZE]; + + int opt_perm_[PERM_NUM_THREE]; // only valid when opt_run_ is true + bool opt_run_; // only true when perm is [1, 0] or [0, 2, 1] + + int (*compute_)(const void *src, void *dst, const int *out_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes); + void (*nhwc2nchw_)(const void *src, void *dst, int b, int hw, int c, int task_id, int thread); + void (*optimize_)(const void *src, void *dst, const int *out_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread); +} TransposeStruct; + +KernelBase *CreateTranspose(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRANSPOSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c new file mode 100644 index 0000000000000000000000000000000000000000..4b64aada0eed70939d38570a1747fd006b0223f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c @@ -0,0 +1,89 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/tril.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/triu_tril_fp32.h" + +int TrilCompute(KernelBase *self) { + TrilStruct *tril = (TrilStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tril); + + int ret = TriuTrilGetKValue(self, &tril->k_); + if (ret != NNACL_OK) { + return ret; + } + + int64_t mul, height, width; + ret = TriuTrilGetCalculateNum(self, &mul, &height, &width); + if (ret != NNACL_OK) { + return ret; + } + + void *src_data = self->in_[FIRST_INPUT]->data_; + void *dst_data = self->out_[OUTPUT_INDEX]->data_; + int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_ZERO_RETURN_ERR(type_size); + + switch (type_size) { + case sizeof(int64_t): { + TrilByte8(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int32_t): { + TrilByte4(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int16_t): { + TrilByte2(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int8_t): { + TrilByte1(src_data, dst_data, tril->k_, height, width, mul); + break; + } + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +KernelBase *CreateTril(OpParameter *param, int data_type) { + TrilStruct *tril = (TrilStruct *)malloc(sizeof(TrilStruct)); + NNACL_CHECK_NULL_RETURN_NULL(tril); + tril->base_.Release = DefaultRelease; + tril->base_.Prepare = DefaultPrepare1In1Out; + tril->base_.Resize = DefaultResize; + tril->base_.Compute = TrilCompute; + return (KernelBase *)tril; +} + +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeDouble, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt8, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt8, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeBool, CreateTril) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h new file mode 100644 index 0000000000000000000000000000000000000000..189325b610076221d5ea47638bd26b2fb428d847 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
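Tril (and Triu below) never look at the element type, only at its width: dispatch is a switch on DataTypeCSize, which is why a single creator is registered for every 1-, 2-, 4- and 8-byte dtype. A reference loop for the 4-byte case, assuming TrilByte4 masks `mul` row-major [height, width] matrices against the k-th diagonal in the usual numpy sense (a sketch, not the nnacl implementation):

#include <stdint.h>

/* Keep elements on or below the k-th diagonal (j <= i + k); zero the rest. */
static void TrilByte4Sketch(const int32_t *src, int32_t *dst, int64_t k, int64_t height, int64_t width, int64_t mul) {
  for (int64_t b = 0; b < mul; ++b) {
    const int32_t *in = src + b * height * width;
    int32_t *out = dst + b * height * width;
    for (int64_t i = 0; i < height; ++i) {
      for (int64_t j = 0; j < width; ++j) {
        out[i * width + j] = (j <= i + k) ? in[i * width + j] : 0;
      }
    }
  }
}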
+ */ + +#ifndef NNACL_KERNEL_TRIL_H_ +#define NNACL_KERNEL_TRIL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct TrilStruct { + KernelBase base_; + int64_t k_; +} TrilStruct; + +KernelBase *CreateTril(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRIL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c new file mode 100644 index 0000000000000000000000000000000000000000..a3121f23fb89418571a3d98bdcfa24b95a4921cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c @@ -0,0 +1,89 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/triu.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/triu_tril_fp32.h" + +int TriuCompute(KernelBase *self) { + TriuStruct *triu = (TriuStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(triu); + + void *src_data = self->in_[FIRST_INPUT]->data_; + void *dst_data = self->out_[OUTPUT_INDEX]->data_; + int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_ZERO_RETURN_ERR(type_size); + + int ret = TriuTrilGetKValue(self, &triu->k_); + if (ret != NNACL_OK) { + return ret; + } + + int64_t mul, height, width; + ret = TriuTrilGetCalculateNum(self, &mul, &height, &width); + if (ret != NNACL_OK) { + return ret; + } + + switch (type_size) { + case sizeof(int64_t): { + TriuByte8(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int32_t): { + TriuByte4(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int16_t): { + TriuByte2(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int8_t): { + TriuByte1(src_data, dst_data, triu->k_, height, width, mul); + break; + } + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +KernelBase *CreateTriu(OpParameter *param, int data_type) { + TriuStruct *triu = (TriuStruct *)malloc(sizeof(TriuStruct)); + NNACL_CHECK_NULL_RETURN_NULL(triu); + triu->base_.Release = DefaultRelease; + triu->base_.Prepare = DefaultPrepare1In1Out; + triu->base_.Resize = DefaultResize; + triu->base_.Compute = TriuCompute; + return (KernelBase *)triu; +} + +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeDouble, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt8, CreateTriu) 
+REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt8, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeBool, CreateTriu) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h new file mode 100644 index 0000000000000000000000000000000000000000..e710cb76afdbc47ea7239bf740e385cb75aac366 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_TRIU_H_ +#define NNACL_KERNEL_TRIU_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct TriuStruct { + KernelBase base_; + int64_t k_; +} TriuStruct; + +KernelBase *CreateTriu(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRIU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c new file mode 100644 index 0000000000000000000000000000000000000000..ee9f8218a57e1293288d0a63fea8f3ca7dde82f0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c @@ -0,0 +1,66 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/unique.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/unique_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/unique_fp16.h" +#endif + +int UniqueCompute(KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output0 = self->out_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(output0); + TensorC *output1 = self->out_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(output1); + + int num = NNACLGetElementNum(input); + int output0_len = 0; + +#ifdef ENABLE_FP16 + if (input->data_type_ == kNumberTypeFloat16) { + UniqueFp16((float16_t *)input->data_, num, (float16_t *)output0->data_, &output0_len, (int *)output1->data_); + } +#endif + if (input->data_type_ == kNumberTypeInt32) { + UniqueInt((int *)input->data_, num, (int *)output0->data_, &output0_len, (int *)output1->data_); + } + if (input->data_type_ == kNumberTypeFloat32) { + Unique((float *)input->data_, num, (float *)output0->data_, &output0_len, (int *)output1->data_); + } + + output0->shape_changed_ = (output0->shape_[output0->shape_size_ - 1] != output0_len); + output0->shape_[output0->shape_size_ - 1] = output0_len; + return NNACL_OK; +} + +KernelBase *CreateUnique(OpParameter *param, int data_type) { + UniqueStruct *unique = (UniqueStruct *)malloc(sizeof(UniqueStruct)); + NNACL_CHECK_NULL_RETURN_NULL(unique); + unique->data_type_ = data_type; + unique->base_.Release = DefaultRelease; + unique->base_.Prepare = DefaultPrepare1In2Out; + unique->base_.Resize = DefaultResize; + unique->base_.Compute = UniqueCompute; + return (KernelBase *)unique; +} + +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeInt32, CreateUnique) +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat32, CreateUnique) +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat16, CreateUnique) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h new file mode 100644 index 0000000000000000000000000000000000000000..3083e92566742e32f8d85f23ac82a949dd41766a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
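UniqueCompute above writes the deduplicated values to the first output and, for every input element, the index of its value within that output to the second, then shrinks the first output's trailing dimension in place and flags it through shape_changed_. Assuming first-occurrence order, the float32 Unique call behaves like this O(n·u) scan (a sketch):

/* For each element, find or append its first occurrence in `out` and
 * record that position in `idx`; *out_len ends as the unique count. */
static void UniqueSketch(const float *in, int num, float *out, int *out_len, int *idx) {
  *out_len = 0;
  for (int i = 0; i < num; ++i) {
    int j = 0;
    while (j < *out_len && out[j] != in[i]) {
      ++j;
    }
    if (j == *out_len) {
      out[(*out_len)++] = in[i];
    }
    idx[i] = j;
  }
}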
+ */
+
+#ifndef NNACL_KERNEL_UNIQUE_H_
+#define NNACL_KERNEL_UNIQUE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct UniqueStruct {
+  KernelBase base_;
+  int data_type_;
+} UniqueStruct;
+
+KernelBase *CreateUnique(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_UNIQUE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c
new file mode 100644
index 0000000000000000000000000000000000000000..15229018f0d7d064b11b4f115c482abbfd58d13b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c
@@ -0,0 +1,298 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/where.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/fp32/where_fp32.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/where_fp16.h"
+#endif
+#include "nnacl_c/base/broadcast_to.h"
+
+int WhereExecuteFp16(WhereStruct *where, int task_id) {
+#ifdef ENABLE_FP16
+  WhereWithTripleInputsFp16((float16_t *)where->x_, (float16_t *)where->y_, (float16_t *)where->output_, &where->args_,
+                            task_id, where->base_.thread_nr_);
+  return NNACL_OK;
+#else
+  /* fp16 kernels are compiled out; fail loudly instead of silently succeeding */
+  (void)where;
+  (void)task_id;
+  return NNACL_UNSUPPORTED_DATA_TYPE;
+#endif
+}
+
+int WhereExecute(WhereStruct *where, int task_id) {
+  WhereWithTripleInputs((float *)where->x_, (float *)where->y_, (float *)where->output_, &where->args_, task_id,
+                        where->base_.thread_nr_);
+  return NNACL_OK;
+}
+
+int WhereRun(void *cdata, int task_id, float l, float r) {
+  WhereStruct *where = (WhereStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(where);
+
+  NNACL_CHECK_NULL_RETURN_ERR(where->x_);
+  NNACL_CHECK_NULL_RETURN_ERR(where->y_);
+  NNACL_CHECK_NULL_RETURN_ERR(where->output_);
+  NNACL_CHECK_NULL_RETURN_ERR(where->args_.condition_);
+
+  if (where->data_type_ == kNumberTypeFloat16) {
+    return WhereExecuteFp16(where, task_id);
+  }
+  return WhereExecute(where, task_id);
+}
+
+int WhereRunWithSingleInput(WhereStruct *where) {
+  TensorC *input = where->base_.in_[FIRST_INPUT];
+  int32_t *int32_condition = NULL;
+  float *fp32_condition = NULL;
+  bool *bool_condition = NULL;
+  switch (where->data_type_) {
+    case kNumberTypeInt32:
+      int32_condition = (int32_t *)input->data_;
+      NNACL_CHECK_NULL_RETURN_ERR(int32_condition);
+      break;
+    case kNumberTypeFloat32:
+      fp32_condition = (float *)input->data_;
+      NNACL_CHECK_NULL_RETURN_ERR(fp32_condition);
+      break;
+    case kNumberTypeBool:
+      bool_condition = (bool *)input->data_;
+      NNACL_CHECK_NULL_RETURN_ERR(bool_condition);
+      break;
+    default:
+      return NNACL_WHERE_CONDITION_DATA_TYPE_ERROR;
+  }
+  WhereArgs *where_args = &where->args_;
+  where_args->condition_num_ = NNACLGetElementNum(input);
+  where_args->rank_ = input->shape_size_;
+  int strides[MAX_SHAPE_SIZE];
+  ComputeStrides(input->shape_, strides, where_args->rank_);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(where_args->condition_num_, where_args->rank_, NNACL_ERR);
+  int data_num_int = where_args->condition_num_ * where_args->rank_;
+  NNACL_CHECK_TRUE_RET(data_num_int >= 0, NNACL_ERR);
+  size_t result_size = (size_t)data_num_int * sizeof(int32_t);
+  int32_t *result = where->base_.env_->Alloc(where->base_.env_->allocator_, result_size);
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(result);
+
+  int result_index = 0;
+  int true_num = 0;
+  /* decode every set position of the flattened condition into its coordinates */
+  for (int index = 0; index < where_args->condition_num_; index++) {
+    bool condition = false;
+    switch (where->data_type_) {
+      case kNumberTypeInt32:
+        condition = (bool)int32_condition[index];
+        break;
+      case kNumberTypeFloat32:
+        condition = (bool)fp32_condition[index];
+        break;
+      case kNumberTypeBool:
+        condition = (bool)bool_condition[index];
+        break;
+      default:
+        where->base_.env_->Free(where->base_.env_->allocator_, result);
+        return NNACL_WHERE_CONDITION_DATA_TYPE_ERROR;
+    }
+    if (condition) {
+      true_num++;
+      int dim = index;
+      for (int j = 0; j < where_args->rank_; j++) {
+        if (strides[j] == 0) {
+          where->base_.env_->Free(where->base_.env_->allocator_, result);
+          return NNACL_ERR;
+        }
+        result[result_index++] = dim / strides[j];
+        dim %= strides[j];
+      }
+    }
+  }
+
+  TensorC *output = where->base_.out_[OUTPUT_INDEX];
+  if (output->data_ != NULL) {
+    /* the output data is expected to be null here; free any stale buffer */
+    where->base_.env_->Free(where->base_.env_->allocator_, output->data_);
+  }
+  int output_shape[] = {true_num, where_args->rank_};
+  output->shape_changed_ = !ShapeEqual(output->shape_, output->shape_size_, output_shape, Num2);
+  output->shape_size_ = Num2;
+  memcpy(output->shape_, output_shape, Num2 * sizeof(int));
+
+  if (true_num > 0) {
+    output->data_ = result;
+  } else {
+    /* nothing selected: release the scratch buffer so it does not leak */
+    where->base_.env_->Free(where->base_.env_->allocator_, result);
+  }
+  return NNACL_OK;
+}
+
+int WhereBroadCastForInput(WhereStruct *where, TensorC *condition, TensorC *x, TensorC *y,
+                           void **condition_broadcast_buf, void **x_broadcast_buf, void **y_broadcast_buf,
+                           TensorC *output) {
+  size_t broad_cast_buf_size = NNACLGetElementNum(output);
+  if (output->data_type_ == kNumberTypeFloat32) {
+    broad_cast_buf_size *= sizeof(float);
+  } else {
+    return NNACL_WHERE_BROAD_CAST_FAILED;
+  }
+  BroadcastShapeInfo condition_info;
+  condition_info.input_shape_size_ = condition->shape_size_;
+  condition_info.output_shape_size_ = output->shape_size_;
+  memcpy(condition_info.input_shape_, condition->shape_, condition->shape_size_ * sizeof(int));
+  memcpy(condition_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
+
+  BroadcastShapeInfo x_info;
+  x_info.input_shape_size_ = x->shape_size_;
+  x_info.output_shape_size_ = output->shape_size_;
+  memcpy(x_info.input_shape_, x->shape_, x->shape_size_ * sizeof(int));
+  memcpy(x_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
+
+  BroadcastShapeInfo y_info;
+  y_info.input_shape_size_ = y->shape_size_;
+  y_info.output_shape_size_ = output->shape_size_;
+  memcpy(y_info.input_shape_, y->shape_, y->shape_size_ * sizeof(int));
+  memcpy(y_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
+
+  *condition_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size);
+  if (*condition_broadcast_buf == NULL) {
+    return NNACL_WHERE_BROAD_CAST_FAILED;
+  }
+  BroadcastToSize8(condition->data_, &condition_info, *condition_broadcast_buf);
+
+  *x_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size);
+  if (*x_broadcast_buf == NULL) {
+    where->base_.env_->Free(where->base_.env_->allocator_, *condition_broadcast_buf);
+    return NNACL_WHERE_BROAD_CAST_FAILED;
+  }
+  BroadcastToSize32(x->data_, &x_info, *x_broadcast_buf);
+
+  *y_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_,
broad_cast_buf_size); + if (*y_broadcast_buf == NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, *condition_broadcast_buf); + where->base_.env_->Free(where->base_.env_->allocator_, *x_broadcast_buf); + return NNACL_WHERE_BROAD_CAST_FAILED; + } + BroadcastToSize32(y->data_, &y_info, *y_broadcast_buf); + return NNACL_OK; +} + +int WhereRunWithTripleInputs(WhereStruct *where) { + TensorC *condition = where->base_.in_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(condition); + TensorC *x = where->base_.in_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(x); + TensorC *y = where->base_.in_[Index2]; + NNACL_CHECK_NULL_RETURN_ERR(y); + TensorC *output = where->base_.out_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int condition_nums = NNACLGetElementNum(condition); + int x_num = NNACLGetElementNum(x); + int y_num = NNACLGetElementNum(y); + int out_num = NNACLGetElementNum(output); + int num_max = condition_nums > x_num ? condition_nums : (x_num > y_num ? x_num : y_num); + + where->x_ = x->data_; + where->y_ = y->data_; + where->output_ = output->data_; + + WhereArgs *args = &where->args_; + args->condition_ = (bool *)condition->data_; + args->condition_num_ = condition_nums; + args->x_num_ = x_num; + args->y_num_ = y_num; + args->max_num_ = num_max; + + void *condition_broadcast_buf = NULL; + void *x_broadcast_buf = NULL; + void *y_broadcast_buf = NULL; + + if (out_num < num_max) { + return NNACL_WHERE_INVALID_OUT_NUM; + } + if (((condition_nums != 1) && (condition_nums != num_max)) || ((x_num != 1) && (x_num != num_max)) || + ((y_num != 1) && (y_num != num_max))) { + if (condition_nums != NNACLGetElementNum(y) && condition->shape_size_ != y->shape_size_) { + int ret = WhereBroadCastForInput(where, condition, x, y, &condition_broadcast_buf, &x_broadcast_buf, + &y_broadcast_buf, output); + if (ret != NNACL_OK) { + return NNACL_WHERE_BROAD_CAST_FAILED; + } + int max_num = NNACLGetElementNum(output); + args->condition_ = (bool *)condition_broadcast_buf; + where->x_ = x_broadcast_buf; + where->y_ = y_broadcast_buf; + where->output_ = output->data_; + args->condition_num_ = max_num; + args->x_num_ = max_num; + args->y_num_ = max_num; + args->max_num_ = max_num; + } else { + /* The length of three inputs are not equal to 1 or length of output, which is unacceptable */ + return NNACL_WHERE_CONDITION_NUM_INVALID; + } + } + if (num_max <= 0) { + /* Error, inputs' length are zero */ + return NNACL_WHERE_NUM_MAX_INVALID; + } + int ret = + where->base_.env_->ParallelLaunch(where->base_.env_->thread_pool_, WhereRun, where, where->base_.thread_nr_); + if (condition_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, condition_broadcast_buf); + condition_broadcast_buf = NULL; + } + if (x_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, x_broadcast_buf); + x_broadcast_buf = NULL; + } + if (y_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, y_broadcast_buf); + y_broadcast_buf = NULL; + } + return ret; +} + +int WhereCompute(KernelBase *self) { + WhereStruct *where = (WhereStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(where); + + int ret = NNACL_ERR; + if (self->in_size_ == Num1) { + ret = WhereRunWithSingleInput(where); + } else if (self->in_size_ == Num3) { + ret = WhereRunWithTripleInputs(where); + } else { + ret = NNACL_WHERE_INPUT_NUM_INVALID; + } + return ret; +} + +int WherePrepare(KernelBase *self) { + NNACL_CHECK_TRUE_RET(self->in_size_ == Num1 || self->in_size_ == Num3, NNACL_WHERE_INPUT_NUM_INVALID); + 
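
The single-input path in WhereRunWithSingleInput recovers an N-dimensional coordinate for every true condition element from its flat index, using the row-major strides filled in by ComputeStrides. A standalone sketch of that div/mod decomposition (demo-only names, same arithmetic as the result[] fill above):

#include <stdio.h>

int main(void) {
  int shape[3] = {2, 3, 4};
  int strides[3];
  int stride = 1;
  for (int i = 2; i >= 0; i--) {  // row-major strides: {12, 4, 1}
    strides[i] = stride;
    stride *= shape[i];
  }
  int dim = 17;  // flat index of element (1, 1, 1)
  int coord[3];
  for (int j = 0; j < 3; j++) {
    coord[j] = dim / strides[j];
    dim %= strides[j];
  }
  printf("(%d, %d, %d)\n", coord[0], coord[1], coord[2]);  // prints (1, 1, 1)
  return 0;
}
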
NNACL_CHECK_TRUE_RET(self->out_size_ == Num1, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +KernelBase *CreateWhere(OpParameter *param, int data_type) { + WhereStruct *where = (WhereStruct *)malloc(sizeof(WhereStruct)); + NNACL_CHECK_NULL_RETURN_NULL(where); + memset(where, 0, sizeof(WhereStruct)); + where->data_type_ = data_type; + where->base_.Prepare = WherePrepare; + where->base_.Compute = WhereCompute; + where->base_.Resize = DefaultResize; + where->base_.Release = DefaultRelease; + return (KernelBase *)where; +} + +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeBool, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeInt32, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeFloat16, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeFloat32, CreateWhere) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h new file mode 100644 index 0000000000000000000000000000000000000000..f281996950c91cfe7052a95a4fb928b6ebc4259e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_WHERE_H_ +#define NNACL_KERNEL_WHERE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct WhereArgs { + int condition_num_; + int x_num_; + int y_num_; + int max_num_; + int rank_; + bool *condition_; +} WhereArgs; + +typedef struct WhereStruct { + KernelBase base_; + WhereArgs args_; + int data_type_; + void *x_; + void *y_; + void *output_; +} WhereStruct; + +KernelBase *CreateWhere(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_WHERE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c new file mode 100644 index 0000000000000000000000000000000000000000..946b92ea4eae9d4b2539a9968a4b819e2226e160 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c @@ -0,0 +1,43 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
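
CreateWhere wires the kernel's entry points into KernelBase function pointers; the runtime, not the kernel, owns the call order. A hedged sketch of how a caller might drive any kernel built this way (assumes the runtime has already populated in_, out_, and env_ as the code above expects):

#include "nnacl_c/kernel.h"

// Hypothetical driver; the Prepare -> Resize -> Compute -> Release order
// mirrors how the function pointers set in CreateWhere are meant to be used.
int RunKernelOnce(KernelBase *kernel) {
  int ret = kernel->Prepare(kernel);  // WherePrepare: checks 1 or 3 inputs, 1 output
  if (ret != NNACL_OK) {
    return ret;
  }
  ret = kernel->Resize(kernel);       // DefaultResize for Where
  if (ret != NNACL_OK) {
    return ret;
  }
  ret = kernel->Compute(kernel);      // WhereCompute: single or triple input path
  kernel->Release(kernel);
  return ret;
}
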
+ */ + +#include "nnacl_c/kernel/zeros_like.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/tensor_c_utils.h" + +int ZerosLikeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + NNACL_CHECK_NULL_RETURN_ERR(output->data_); + (void)memset(output->data_, 0, NNACLGetSize(output)); + return NNACL_OK; +} + +KernelBase *CreateZerosLike(OpParameter *param, int data_type) { + ZerosLikeStruct *zeros_like = (ZerosLikeStruct *)malloc(sizeof(ZerosLikeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(zeros_like); + zeros_like->base_.Release = DefaultRelease; + zeros_like->base_.Prepare = DefaultPrepare1In1Out; + zeros_like->base_.Resize = DefaultResize; + zeros_like->base_.Compute = ZerosLikeCompute; + return (KernelBase *)zeros_like; +} + +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeInt32, CreateZerosLike) +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeFloat32, CreateZerosLike) +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeFloat16, CreateZerosLike) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h new file mode 100644 index 0000000000000000000000000000000000000000..24085b3a1b24df67876504ae0bf1829a4ff65a1b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ZEROS_LIKE_H_ +#define NNACL_KERNEL_ZEROS_LIKE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ZerosLikeStruct { + KernelBase base_; +} ZerosLikeStruct; + +KernelBase *CreateZerosLike(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ZEROS_LIKE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..8476503c03e3834be76256fad76e4d30e711ae81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
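
ZerosLikeCompute zeroes the whole output buffer in one memset, so its only sizing dependency is NNACLGetSize. Assuming that helper returns element count times per-element width (the same quantity DataTypeCSize yields per element later in this patch), the byte math looks like:

#include <stddef.h>

// Demo-only helper; the assumed semantics of NNACLGetSize(tensor).
static size_t TensorBytes(const int *shape, int ndim, size_t elem_size) {
  size_t count = 1;
  for (int i = 0; i < ndim; i++) {
    count *= (size_t)shape[i];
  }
  return count * elem_size;  // a [3, 4] fp32 tensor: 12 * 4 = 48 bytes
}
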
+ */ +#ifndef NNACL_L2NORM_PARAMETER_H_ +#define NNACL_L2NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct L2NormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + int axis_[MAX_SHAPE_SIZE]; + // shape correlative + size_t axis_num_; + int data_num_; + int *shape_; + size_t shape_num_; + // other parameter + ActType act_type_; +} L2NormParameter; + +typedef struct { + QuantArg in_; + QuantArg out_; +} L2NormQuantArg; + +#endif // NNACL_L2NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e2b7bc45f04cd1e457d4bf55b075f515ccca97b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_LAYER_NORM_PARAMETER_H_ +#define NNACL_LAYER_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct LayerNormParameter { + OpParameter op_parameter_; + float epsilon_; + bool elementwise_affine_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormParameter; + +typedef struct LayerNormQuantArg { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} LayerNormQuantArg; + +#endif // NNACL_LAYER_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ebb5f6fb1dff5312f53a2ee1262bcd62a4cbdc2b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
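
In LayerNormParameter, begin_norm_axis_ partitions the shape into independent rows and the trailing span that shares one mean and variance, with epsilon_ guarding the divide. A reference loop under the usual layer-norm definition (a sketch, not the tuned nnacl kernel; for shape [2, 3, 4] and begin_norm_axis_ = 2, rows = 6 and inner = 4):

#include <math.h>

void LayerNormRef(const float *in, const float *gamma, const float *beta, float *out,
                  int rows, int inner, float epsilon) {
  for (int r = 0; r < rows; r++) {
    const float *x = in + r * inner;
    float *y = out + r * inner;
    float mean = 0.0f;
    float var = 0.0f;
    for (int i = 0; i < inner; i++) mean += x[i];
    mean /= (float)inner;
    for (int i = 0; i < inner; i++) var += (x[i] - mean) * (x[i] - mean);
    var /= (float)inner;
    const float inv_std = 1.0f / sqrtf(var + epsilon);  // epsilon avoids 1/0
    for (int i = 0; i < inner; i++) y[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
  }
}
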
+ */ + +#ifndef NNACL_LRM_PARAMETER_H_ +#define NNACL_LRM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LocalResponseNormParameter { + OpParameter op_parameter_; + int depth_radius_; + float bias_; + float alpha_; + float beta_; +} LocalResponseNormParameter; + +#endif // NNACL_LRM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b3bb2e5b094857f214a000a5a356669242ac8be7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_LSH_PROJECTION_PARAMETER_H_ +#define NNACL_LSH_PROJECTION_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LshProjectionParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int hash_shape_[2]; + // other parameter + int lsh_type_; + int feature_num_; + char **hash_buffs_; + size_t hash_buff_size_; + int64_t thread_stride_; +} LshProjectionParameter; + +#endif // NNACL_LSH_PROJECTION_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..563b1c29e8e0931f04eb5c33eacaecc523167dba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
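
The depth_radius_, bias_, alpha_, and beta_ fields map onto the classic LRN definition b_c = a_c / (bias + alpha * sum of a_i^2 over the window)^beta, where the window spans depth_radius_ channels on each side. A reference loop for one pixel, assuming TF-style semantics (alpha not pre-divided by the window size):

#include <math.h>

void LocalResponseNormRef(const float *in, float *out, int channels,
                          int depth_radius, float bias, float alpha, float beta) {
  for (int c = 0; c < channels; c++) {
    const int lo = c - depth_radius < 0 ? 0 : c - depth_radius;
    const int hi = c + depth_radius >= channels ? channels - 1 : c + depth_radius;
    float sum = 0.0f;
    for (int i = lo; i <= hi; i++) {
      sum += in[i] * in[i];
    }
    out[c] = in[c] / powf(bias + alpha * sum, beta);
  }
}
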
+ */ +#ifndef NNACL_LSTM_PARAMETER_H_ +#define NNACL_LSTM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LstmParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; + int project_size_; + int output_size_; + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + float zoneout_cell_; + float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; + int proj_col_align_; + bool has_bias_; +} LstmParameter; + +#endif // NNACL_LSTM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e23edf89cc6590b5f50b6423be0b6ee29055b313 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MATMUL_H_ +#define NNACL_MATMUL_H_ + +#include "nnacl_c/op_base.h" + +typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias); + +typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, + int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel); + +typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, + int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel, + const int32_t *filter_zp); + +typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2, OutType_NC4HW4 = 3 } OutType; + +typedef enum MatmulType { + // reserve 0 for base op + kNotImplemented = 0, + kMatmulInt8Cpu, + kMatmulDynamicInt8Cpu, + kMatmulDynamicSdotInt8Cpu, + kMatmulFp32BaseCpu, + kMatmulFp32Arm64Cpu, +} MatmulType; + +typedef struct MatMulParameter { + // Primitive parameter + OpParameter op_parameter_; + bool has_bias_; + bool use_axis_; + bool a_transpose_; /* false : row-major */ + bool b_transpose_; /* true : col-major */ + ActType act_type_; + + // other parameter + int row_; + int col_; + int row_4_; + int row_16_; + int row_align_; + int col_8_; + int col_align_; + int deep_; + int deep_4_; + int deep_16_; + int deep_align_; + int batch; + bool a_const_; + bool b_const_; + int axis_; + MatmulType matmul_type_; +} MatMulParameter; + +typedef struct MatmulQuantParameter { + QuantArg input_; + QuantArg weight_; + QuantArg output_; + int32_t out_act_min_; 
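
The comments on a_transpose_ and b_transpose_ pin down the storage layouts the matmul kernels expect. A naive reference that honors those flags (illustration only; the production kernels are the tiled, vectorized variants elsewhere in this patch):

#include <stdbool.h>

void MatMulRef(const float *a, const float *b, float *c, int row, int col, int deep,
               bool b_transpose) {
  for (int r = 0; r < row; r++) {       // a_transpose_ == false: A is [row, deep]
    for (int j = 0; j < col; j++) {
      float acc = 0.0f;
      for (int d = 0; d < deep; d++) {  // b_transpose_ == true: B is [col, deep]
        acc += a[r * deep + d] * (b_transpose ? b[j * deep + d] : b[d * col + j]);
      }
      c[r * col + j] = acc;
    }
  }
}
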
+ int32_t out_act_max_; + float *filter_scale_; + int32_t *filter_zp_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; +} MatmulQuantParameter; + +typedef struct MatmulDynamicQuantParameter { + float *input_scale_; + int32_t *input_zp_; + float *filter_scale_; + int32_t *filter_zp_; +} MatmulDynamicQuantParameter; + +#endif // NNACL_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..43cb7279b3f58a19e34e5840175b108340714896 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MUL_PARAMETER_H_ +#define NNACL_MUL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct MulQuantArg { + QuantArg in_quant_args_[2]; + QuantArg out_quant_arg_; + int output_multiplier_; + int output_activation_min_; + int output_activation_max_; + int shift_left_; + int shift_right_; +} MulQuantArg; + +#endif // NNACL_MUL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b112246f2f76c50cfbb79f2a77cdd9519f4a0b84 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NLLLOSS_PARAMETER_H_ +#define NNACL_NLLLOSS_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct NLLLossParameter { + OpParameter op_parameter_; + ReductionType reduction_type_; +} NLLLossParameter; + +#endif // NNACL_NLLLOSS_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c new file mode 100644 index 0000000000000000000000000000000000000000..9cc3a1a0e726a8e6125cdb8dbce7d2458fd83c20 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
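
The multiplier/shift fields in MulQuantArg and MatmulQuantParameter carry a real-valued rescale factor in 32-bit fixed point. One common derivation (TFLite-style, shown as an assumption rather than nnacl's exact recipe) splits the factor with frexp:

#include <math.h>
#include <stdint.h>

void QuantizeMultiplierDemo(double real_multiplier, int32_t *quantized, int *shift) {
  if (real_multiplier == 0.0) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const double q = frexp(real_multiplier, shift);  // real = q * 2^shift, q in [0.5, 1)
  int64_t q_fixed = (int64_t)llround(q * (1ll << 31));
  if (q_fixed == (1ll << 31)) {  // rounding pushed q up to 1.0
    q_fixed /= 2;
    ++*shift;
  }
  *quantized = (int32_t)q_fixed;  // e.g. 0.25 -> multiplier 2^30, shift -1
}
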
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/nnacl_common.h"
+
+typedef union float32_bits {
+  unsigned int u;
+  float f;
+} float32_bits;
+
+float ShortToFloat32(uint16_t src_value) {
+  const float32_bits magic = {113 << 23};
+  const unsigned int shifted_exp = 0x7c00 << 13;
+  float32_bits o;
+
+  o.u = (src_value & 0x7fff) << 13;
+  unsigned int exp = shifted_exp & o.u;
+  o.u += (127 - 15) << 23;
+
+  if (exp == shifted_exp) {
+    o.u += (128 - 16) << 23;
+  } else if (exp == 0) {
+    o.u += 1 << 23;
+    o.f -= magic.f;
+  }
+
+  o.u |= (src_value & 0x8000) << 16;
+  return o.f;
+}
+
+uint16_t Float32ToShort(float src_value) {
+  float32_bits src_value_bits;
+  src_value_bits.f = src_value;
+  uint16_t res = 0;
+  // mantissa: keep the top 10 significand bits
+  res += (src_value_bits.u >> 13) & 0x3ff;
+  // exponent
+  res += (src_value_bits.u >> 13) & 0x3fc00;
+  res -= (127 - 15) << 10;
+
+  // sign
+  res |= (src_value_bits.u & 0x80000000) >> 16;
+  return res;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..df2a9a0a8f1bc4a383800ea0b4d5815d002465ed
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
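
A quick sanity check for the pair above: 1.0f is 0x3c00 in IEEE-754 half precision (sign 0, biased exponent 15, mantissa 0), and for normal values the two helpers should round-trip it:

#include <stdint.h>
#include <stdio.h>
#include "nnacl_c/nnacl_common.h"

int main(void) {
  const uint16_t h = Float32ToShort(1.0f);
  printf("0x%04x\n", h);              // expect 0x3c00
  printf("%f\n", ShortToFloat32(h));  // expect 1.000000
  return 0;
}
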
+ */
+
+#ifndef NNACL_NNACL_COMMON_H_
+#define NNACL_NNACL_COMMON_H_
+
+#include <stddef.h>
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+static inline size_t DataTypeCSize(TypeIdC type) {
+  switch (type) {
+    case kNumberTypeFloat64:
+      return sizeof(double);
+    case kNumberTypeFloat:
+    case kNumberTypeFloat32:
+      return sizeof(float);
+    case kNumberTypeInt8:
+      return sizeof(int8_t);
+    case kNumberTypeUInt8:
+      return sizeof(uint8_t);
+    case kNumberTypeFloat16:
+    case kNumberTypeInt16:
+      return sizeof(int16_t);
+    case kNumberTypeInt32:
+      return sizeof(int32_t);
+    case kNumberTypeInt64:
+      return sizeof(int64_t);
+    case kNumberTypeUInt16:
+      return sizeof(uint16_t);
+    case kNumberTypeUInt32:
+      return sizeof(uint32_t);
+    case kNumberTypeUInt64:
+      return sizeof(uint64_t);
+    case kNumberTypeComplex64:
+      return sizeof(float) + sizeof(float);
+    case kNumberTypeComplex128:
+      return sizeof(double) + sizeof(double);
+    case kNumberTypeBool:
+      return sizeof(bool);
+    case kObjectTypeString:
+      return sizeof(char);
+    case kObjectTypeTensorType:
+      return 0;
+    case kMetaTypeTypeType:
+      return sizeof(int);
+    default:
+      return 0;
+  }
+}
+
+static inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
+  int stride = 1;
+  for (int i = ndim - 1; i >= 0; i--) {
+    strides[i] = stride;
+    stride *= shape[i];
+  }
+}
+
+static inline void ComputeAxisDims(const int *shape, int shape_size, int axis, int *out_count, int *axis_count,
+                                   int *in_count) {
+  *out_count = 1;
+  *in_count = 1;
+  for (int i = 0; i < shape_size; i++) {
+    if (i < axis) {
+      *out_count = (*out_count) * shape[i];
+    }
+    if (i == axis) {
+      *axis_count = shape[axis];
+    }
+    if (i > axis) {
+      *in_count = (*in_count) * shape[i];
+    }
+  }
+}
+
+static const unsigned int FP32_BIT_SIZE = 32;
+static const unsigned int FP32_EXPONENT_BIAS = 127;
+static const unsigned int FP32_SIGNIFICAND = 23;
+static const unsigned int FP32_EXPONENT_MAX = 255;
+static const unsigned int FP16_BIT_SIZE = 16;
+static const unsigned int FP16_EXPONENT_BIAS = 15;
+static const unsigned int FP16_SIGNIFICAND = 10;
+static const int FP16_EXPONENT_MAX = 30;
+static const int FP16_EXPONENT_MIN = -10;
+static const int FP16_SHIFT = 13;
+float ShortToFloat32(uint16_t src_value);
+uint16_t Float32ToShort(float src_value);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_NNACL_COMMON_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c
new file mode 100644
index 0000000000000000000000000000000000000000..20815a161562ba4c90d53fd7b40e024fceb5cb7e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
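
ComputeAxisDims is the usual outer/axis/inner factorization for reductions: out_count * axis_count * in_count always equals the element count. Worked example:

#include <stdio.h>
#include "nnacl_c/nnacl_common.h"

int main(void) {
  int shape[4] = {2, 3, 4, 5};
  int out_count, axis_count, in_count;
  ComputeAxisDims(shape, 4, 2, &out_count, &axis_count, &in_count);
  // 6 outer slices, 4 elements along axis 2, 5 inner elements: 6 * 4 * 5 = 120
  printf("%d %d %d\n", out_count, axis_count, in_count);  // prints 6 4 5
  return 0;
}
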
+ */
+
+#include "nnacl_c/nnacl_utils.h"
+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
+#include <sys/auxv.h>
+#endif
+
+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
+uint32_t getHwCap(int hwcap_type) {
+  uint32_t ret = getauxval(hwcap_type);
+  return ret;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..dfc198787c9c62476b87388bf68a70f382f0bd64
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_NNACL_UTILS_H_
+#define NNACL_NNACL_UTILS_H_
+
+#include <stdint.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#if defined(__arm__) || defined(__aarch64__)
+uint32_t getHwCap(int hwcap_type);
+#endif
+
+#ifdef DEBUG
+#include <assert.h>
+#define NNACL_ASSERT(f) assert(f)
+#else
+#define NNACL_ASSERT(f) ((void)0)
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_NNACL_UTILS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..06557138f285d73c6a8c09efdaa58e80c9d17c92
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
+#define NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct NMSParameter {
+  // Primitive parameter
+  OpParameter op_parameter_;
+  int center_point_box_;
+} NMSParameter;
+
+#endif  // NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..b1650c36b5b74efdf636375b7078d915c73e8c97
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
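
getHwCap is a thin wrapper over getauxval, used to gate optional kernels on runtime CPU features. A sketch of detecting the Armv8.2 dot-product extension (HWCAP_ASIMDDP comes from Linux's <asm/hwcap.h>; an assumed deployment target, not part of this patch):

#if defined(__aarch64__) && (defined(__ANDROID__) || defined(MS_COMPILE_OHOS))
#include <asm/hwcap.h>
#include <stdbool.h>
#include "nnacl_c/nnacl_utils.h"

static bool CpuSupportsSdot(void) {
  return (getHwCap(AT_HWCAP) & HWCAP_ASIMDDP) != 0;  // sdot/udot available
}
#endif
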
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_ONE_HOT_PARAMETER_H_
+#define NNACL_ONE_HOT_PARAMETER_H_
+#include "nnacl_c/op_base.h"
+
+typedef struct OneHotParameter {
+  OpParameter op_parameter_;
+  int axis_;
+} OneHotParameter;
+
+#endif  // NNACL_ONE_HOT_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1a9c40c3f23ff26242c85c1ad3f9317080d72f4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h
@@ -0,0 +1,802 @@
+/**
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
+#include <limits.h>
+#include <stdarg.h>
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+#define C0NUM 0
+#define C1NUM 1
+#define C2NUM 2
+#define C3NUM 3
+#define C4NUM 4
+#define C5NUM 5
+#define C6NUM 6
+#define C7NUM 7
+#define C8NUM 8
+#define C9NUM 9
+#define C10NUM 10
+#define C11NUM 11
+#define C12NUM 12
+#define C13NUM 13
+#define C14NUM 14
+#define C15NUM 15
+#define C16NUM 16
+#define C17NUM 17
+#define C18NUM 18
+#define C19NUM 19
+#define C20NUM 20
+#define C21NUM 21
+#define C22NUM 22
+#define C23NUM 23
+#define C24NUM 24
+#define C28NUM 28
+#define C32NUM 32
+#define C36NUM 36
+#define C40NUM 40
+#define C44NUM 44
+#define C48NUM 48
+#define C56NUM 56
+#define C64NUM 64
+#define C128NUM 128
+#define C150NUM 150
+#define C256NUM 256
+#define C512NUM 512
+#define C1500NUM 1500
+#define TILE_NUM 8
+#define MAX_SPLIT_NUM 2048
+
+#define FP16_DATA_TYPE_LEN 2
+
+#ifndef MS_UNLIKELY
+#ifdef _MSC_VER
+#define MS_UNLIKELY(x) (x)
+#else
+#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0)
+#endif
+#endif
+
+#ifndef MS_LIKELY
+#ifdef _MSC_VER
+#define MS_LIKELY(x) (x)
+#else
+#define MS_LIKELY(x) __builtin_expect(!!(x), 1)
+#endif
+#endif
+
+#define NNACL_MIN(x, y) ((x) < (y) ? (x) : (y))
+#define NNACL_MAX(x, y) ((x) > (y) ? (x) : (y))
+
+#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
+#define MSMAX(x, y) ((x) > (y) ? (x) : (y))
+#define MSCEIL(x) (int)((x) + (((x) - (int)(x)) > 0 ? 1 : 0))
+
+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
+#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))
+#define DOWN_DIV(x, y) ((x) / (y))
+#define DOWN_ROUND(x, y) ((x) / (y) * (y))
+
+#define MSVALID(left, x, right) (MSMIN((MSMAX(left, x)), right))
+#define SIZE_MUL_OVERFLOW(x, y) (((x) == 0) ?
false : (SIZE_MAX / (x)) < (y)) +#define INT_MUL_OVERFLOW(x, y) \ + (((x) == 0) ? false \ + : ((x) > 0 ? (((y) >= 0) ? (INT_MAX / (x)) < (y) : (INT_MAX / (x)) < (-1 * (y))) \ + : (((y) >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y)))) + +#define INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold) \ + (((x) == 0) ? false \ + : ((x) > 0 ? (((y) >= 0) ? ((threshold) / (x)) < (y) : ((threshold) / (x)) < (-1 * (y))) \ + : (((y) >= 0) ? ((threshold) / (x)) > (-1 * (y)) : ((threshold) / (x)) > (y)))) + +#define INT_ADD_OVERFLOW(x, y) (INT_MAX - (x)) < (y) + +#define INT_ADD_OVERFLOW_THRESHOLD(x, y, threshold) ((threshold) - (x)) < (y) + +#define MALLOC_MAX_SIZE (2000 * 1024 * 1024) + +#define COMM_SHAPE_SIZE 4 +#define MAX_SHAPE_SIZE 8 + +#define OUTPUT_INDEX 0 +#define FIRST_INPUT 0 +#define SECOND_INPUT 1 +#define THIRD_INPUT 2 +#define FOURTH_INPUT 3 +#define FIFTH_INPUT 4 +#define SIXTH_INPUT 5 +#define SEVENTH_INPUT 6 +#define EIGHTH_INPUT 7 +#define NINTH_INPUT 8 + +#define ONE_TENSOR 1 +#define TWO_TENSOR 2 +#define THREE_TENSOR 3 +#define FOUR_TENSOR 4 +#define FIVE_TENSOR 5 + +#define Index0 0 +#define Index1 1 +#define Index2 2 +#define Index3 3 +#define Index4 4 +#define Index5 5 +#define Index6 6 +#define Index7 7 +#define Index8 8 +#define Index9 9 + +#define Num0 0 +#define Num1 1 +#define Num2 2 +#define Num3 3 +#define Num4 4 +#define Num5 5 +#define Num6 6 +#define Num7 7 +#define Num8 8 +#define Num9 9 + +#define DIMENSION_0D 0 +#define DIMENSION_1D 1 +#define DIMENSION_2D 2 +#define DIMENSION_3D 3 +#define DIMENSION_4D 4 +#define DIMENSION_5D 5 +#define DIMENSION_6D 6 +#define DIMENSION_7D 7 +#define DIMENSION_8D 8 +#define DIMENSION_9D 9 +#define DIMENSION_10D 10 +#define DIMENSION_11D 11 +#define kInputIndex 0 +#define kWeightIndex 1 +#define kBiasIndex 2 +#define kOutputIndex 0 +#define kNHWC_N 0 +#define kNHWC_H 1 +#define kNHWC_W 2 +#define kNHWC_C 3 +#define kNCHW_N 0 +#define kNCHW_C 1 +#define kNCHW_H 2 +#define kNCHW_W 3 +#define kHWCN_C 2 +#define kHWNC_N 2 +#define kHWCN_N 3 +#define kNDHWC_N 0 +#define kNDHWC_D 1 +#define kNDHWC_H 2 +#define kNDHWC_W 3 +#define kNDHWC_C 4 +#define kInputSize1 2 +#define kInputSize2 3 +#define MAX_AXIS_SIZE 6 +#define MAX_LEN 256 +#define MAX_THREAD_NUM 64 +#define FLT16_MAX 65504 +#define kDefaulLiteMaxSpinCount 300000 +#define kDefaulLiteMinSpinCount 1 +#define kDefaulLiteIosSpinCount 1 +#define DEFAULT_GROUP_NAME_LEN 101 +#define kValueThreshold6 6 + +#define INVALID_SHAPE -1 + +#define CLARGSINDEX0 0 +#define CLARGSINDEX1 1 +#define CLARGSINDEX2 2 +#define CLARGSINDEX3 3 +#define CLARGSINDEX4 4 +#define CLARGSINDEX5 5 +#define CLARGSINDEX6 6 +#define CLARGSINDEX7 7 +#define CLARGSINDEX8 8 +#define CLARGSINDEX9 9 + +#define CLIDX_X 0 +#define CLIDX_Y 1 +#define CLIDX_Z 2 +#define CLIDX_W 3 + +#define RELU6_MIN_VAL 0 +#define RELU6_MAX_VAL 6 + +/* index for primitive_type & activation_type */ +#define TC_PTYPE(primitive_type) (primitive_type << 16) +#define TC_ATYPE(activation_type) (activation_type) +#define TC_TYPE(primitive_type, activation_type) (TC_PTYPE(primitive_type) + TC_ATYPE(activation_type)) + +#define NNACL_MALLOC_CHECK_NULL_RETURN_ERR(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + } while (0) + +#define NNACL_MALLOC_CHECK_NULL_RETURN_NULL(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NULL; \ + } \ + } while (0) + +#if ENABLE_HIGH_PERFORMANCE +#define NNACL_CHECK_TRUE_RET(value, errcode) +#define NNACL_CHECK_TRUE_RET_VOID(value) +#define NNACL_CHECK_FALSE(value, errcode) +#define 
NNACL_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) +#define NNACL_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) + +#define NNACL_CHECK_ZERO_RETURN_ERR(val) +#define NNACL_CHECK_ZERO_RETURN(val) +#define NNACL_CHECK_NULL_RETURN_ERR(ptr) +#define NNACL_CHECK_NULL_RETURN_VOID(ptr) +#define NNACL_CHECK_NULL_RETURN_NULL(ptr) +#define NNACL_CHECK_MALLOC_SIZE(val) +#else +#define NNACL_CHECK_TRUE_RET(value, errcode) \ + do { \ + if (!(value)) { \ + return errcode; \ + } \ + } while (0) + +#define NNACL_CHECK_TRUE_RET_VOID(value) \ + do { \ + if (!(value)) { \ + return; \ + } \ + } while (0) + +// Check whether value is false, if not return 'errcode' +#define NNACL_CHECK_FALSE(value, errcode) \ + do { \ + if ((value)) { \ + return errcode; \ + } \ + } while (0) + +#define NNACL_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) \ + NNACL_CHECK_TRUE_RET(!(INT_MUL_OVERFLOW(value1, value2)), errcode) +#define NNACL_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) \ + NNACL_CHECK_TRUE_RET(!(INT_ADD_OVERFLOW(value1, value2)), errcode) +#define NNACL_CHECK_MALLOC_SIZE(malloc_size) \ + NNACL_CHECK_FALSE((malloc_size) > MALLOC_MAX_SIZE, NNACL_MALLOC_SIZE_INVALID) + +#define NNACL_CHECK_ZERO_RETURN_ERR(val) \ + do { \ + if ((val) == 0) { \ + return NNACL_ERR; \ + } \ + } while (0) + +#define NNACL_CHECK_ZERO_RETURN(val) \ + do { \ + if ((val) == 0) { \ + return; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_ERR(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_VOID(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_NULL(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NULL; \ + } \ + } while (0) +#endif + +enum PrimType { + PrimType_NONE = 0, + PrimType_Abs = 1, + PrimType_Activation = 2, + PrimType_ActivationGrad = 3, + PrimType_Adam = 4, + PrimType_AddFusion = 5, + PrimType_AdderFusion = 6, + PrimType_AddGrad = 7, + PrimType_AddN = 8, + PrimType_All = 9, + PrimType_ApplyMomentum = 10, + PrimType_ArgMaxFusion = 11, + PrimType_ArgMinFusion = 12, + PrimType_Assert = 13, + PrimType_Assign = 14, + PrimType_AssignAdd = 15, + PrimType_AudioSpectrogram = 16, + PrimType_AvgPoolFusion = 17, + PrimType_AvgPoolGrad = 18, + PrimType_BatchNorm = 19, + PrimType_BatchNormGrad = 20, + PrimType_BatchToSpace = 21, + PrimType_BatchToSpaceND = 22, + PrimType_BiasAdd = 23, + PrimType_BinaryCrossEntropy = 24, + PrimType_BinaryCrossEntropyGrad = 25, + PrimType_BiasAddGrad = 26, + PrimType_BroadcastTo = 27, + PrimType_Cast = 28, + PrimType_Ceil = 29, + PrimType_Clip = 30, + PrimType_Concat = 31, + PrimType_Attention = 32, + PrimType_Conv2DBackpropFilterFusion = 33, + PrimType_Conv2DBackpropInputFusion = 34, + PrimType_Conv2DFusion = 35, + PrimType_Conv2dTransposeFusion = 36, + PrimType_Cos = 37, + PrimType_ConstantOfShape = 38, + PrimType_Crop = 39, + PrimType_CustomExtractFeatures = 40, + PrimType_CustomNormalize = 41, + PrimType_CustomPredict = 42, + PrimType_DeConv2DGradFilter = 43, + PrimType_Depend = 44, + PrimType_DepthToSpace = 45, + PrimType_DetectionPostProcess = 46, + PrimType_DivFusion = 47, + PrimType_DivGrad = 48, + PrimType_Dropout = 49, + PrimType_DropoutGrad = 50, + PrimType_Elu = 51, + PrimType_Eltwise = 52, + PrimType_Equal = 53, + PrimType_EmbeddingLookupFusion = 54, + PrimType_ExpFusion = 55, + PrimType_ExpandDims = 56, + PrimType_FakeQuantWithMinMaxVars = 57, + PrimType_FakeQuantWithMinMaxVarsPerChannel = 58, + PrimType_FftReal = 59, + PrimType_FftImag 
= 60, + PrimType_Flatten = 61, + PrimType_FlattenGrad = 62, + PrimType_Floor = 63, + PrimType_FloorDiv = 64, + PrimType_FloorMod = 65, + PrimType_Fill = 66, + PrimType_FullConnection = 67, + PrimType_FusedBatchNorm = 68, + PrimType_Gather = 69, + PrimType_GatherNd = 70, + PrimType_Greater = 71, + PrimType_GreaterEqual = 72, + PrimType_HashtableLookup = 73, + PrimType_InstanceNorm = 74, + PrimType_LayerNormFusion = 75, + PrimType_LeakyRelu = 76, + PrimType_Less = 77, + PrimType_LessEqual = 78, + PrimType_Log = 79, + PrimType_LogGrad = 80, + PrimType_LogicalAnd = 81, + PrimType_LogicalNot = 82, + PrimType_LogicalOr = 83, + PrimType_LpNormalization = 84, + PrimType_LRN = 85, + PrimType_LshProjection = 86, + PrimType_LSTM = 87, + PrimType_L2NormalizeFusion = 88, + PrimType_MatMulFusion = 89, + PrimType_Maximum = 90, + PrimType_MaximumGrad = 91, + PrimType_MaxPoolFusion = 92, + PrimType_MaxPoolGrad = 93, + PrimType_SwitchLayer = 94, + PrimType_Mfcc = 95, + PrimType_Minimum = 96, + PrimType_MinimumGrad = 97, + PrimType_Mod = 98, + PrimType_MulFusion = 99, + PrimType_MulGrad = 100, + PrimType_Neg = 101, + PrimType_NegGrad = 102, + PrimType_NotEqual = 103, + PrimType_NonMaxSuppression = 104, + PrimType_OneHot = 105, + PrimType_OnesLike = 106, + PrimType_PadFusion = 107, + PrimType_PartialFusion = 108, + PrimType_PowerGrad = 109, + PrimType_PowFusion = 110, + PrimType_PriorBox = 111, + PrimType_PReLUFusion = 112, + PrimType_QuantDTypeCast = 113, + PrimType_Rank = 114, + PrimType_Range = 115, + PrimType_Reciprocal = 116, + PrimType_RealDiv = 117, + PrimType_ReduceFusion = 118, + PrimType_Reshape = 119, + PrimType_Resize = 120, + PrimType_ReverseSequence = 121, + PrimType_ReverseV2 = 122, + PrimType_Rfft = 123, + PrimType_ROIPooling = 124, + PrimType_Round = 125, + PrimType_Rsqrt = 126, + PrimType_ScaleFusion = 127, + PrimType_ScatterNd = 128, + PrimType_SGD = 129, + PrimType_Shape = 130, + PrimType_SigmoidCrossEntropyWithLogits = 131, + PrimType_SigmoidCrossEntropyWithLogitsGrad = 132, + PrimType_Sin = 133, + PrimType_SkipGram = 134, + PrimType_SliceFusion = 135, + PrimType_SmoothL1Loss = 136, + PrimType_SmoothL1LossGrad = 137, + PrimType_Softmax = 138, + PrimType_SoftmaxCrossEntropyWithLogits = 139, + PrimType_SpaceToBatch = 140, + PrimType_SpaceToBatchND = 141, + PrimType_SpaceToDepth = 142, + PrimType_SparseSoftmaxCrossEntropyWithLogits = 143, + PrimType_SparseToDense = 144, + PrimType_Split = 145, + PrimType_Sqrt = 146, + PrimType_Squeeze = 147, + PrimType_Square = 148, + PrimType_SquaredDifference = 149, + PrimType_Stack = 150, + PrimType_StridedSlice = 151, + PrimType_SubFusion = 152, + PrimType_SubGrad = 153, + PrimType_Switch = 154, + PrimType_TensorListFromTensor = 155, + PrimType_TensorListGetItem = 156, + PrimType_TensorListReserve = 157, + PrimType_TensorListSetItem = 158, + PrimType_TensorListStack = 159, + PrimType_TileFusion = 160, + PrimType_TopKFusion = 161, + PrimType_Transpose = 162, + PrimType_Unique = 163, + PrimType_UnsortedSegmentSum = 164, + PrimType_Unsqueeze = 165, + PrimType_Unstack = 166, + PrimType_LSTMGrad = 167, + PrimType_Where = 168, + PrimType_ZerosLike = 169, + PrimType_Select = 170, + PrimType_ScatterNdUpdate = 171, + PrimType_GRU = 172, + PrimType_NonZero = 173, + PrimType_InvertPermutation = 174, + PrimType_Size = 175, + PrimType_RandomStandardNormal = 176, + PrimType_CropAndResize = 177, + PrimType_Erf = 178, + PrimType_StridedSliceGrad = 179, + PrimType_IsFinite = 180, + PrimType_LinSpace = 181, + PrimType_UniformReal = 182, + PrimType_AbsGrad = 183, + 
PrimType_RsqrtGrad = 184, + PrimType_SqrtGrad = 185, + PrimType_LayerNormGrad = 186, + PrimType_ResizeGrad = 187, + PrimType_Splice = 188, + PrimType_LogSoftmax = 189, + PrimType_Call = 190, + PrimType_Custom = 191, + PrimType_CumSum = 192, + PrimType_SplitWithOverlap = 193, + PrimType_GenOP = 194, + PrimType_RaggedRange = 195, + PrimType_GLU = 196, + PrimType_TensorArray = 197, + PrimType_TensorArrayRead = 198, + PrimType_TensorArrayWrite = 199, + PrimType_Affine = 200, + PrimType_AllGather = 201, + PrimType_ReduceScatter = 202, + PrimType_DynamicQuant = 203, + PrimType_LSTMGradData = 204, + PrimType_LSTMGradWeight = 205, + PrimType_RandomNormal = 206, + PrimType_NLLLoss = 207, + PrimType_NLLLossGrad = 208, + PrimType_FormatTranspose = 209, + PrimType_GatherD = 210, + PrimType_GroupNormFusion = 211, + PrimType_Log1p = 212, + PrimType_TensorScatterAdd = 213, + PrimType_SparseFillEmptyRows = 214, + PrimType_SparseReshape = 215, + PrimType_SparseSegmentSum = 216, + PrimType_ScatterElements = 217, + PrimType_Triu = 218, + PrimType_Tril = 219, + PrimType_AdamWeightDecay = 220, + PrimType_FillV2 = 221, + PrimType_MIN = PrimType_NONE, + PrimType_MAX = PrimType_FillV2 + 1, + + // inner operators. + PrimType_Inner_ToFormat = 10000, + PrimType_Inner_GltextureToOpencl = 10001, + PrimType_Inner_Identity = 10002, + PrimType_Inner_ShapeFusion = 10003, + PrimType_Inner_GraphKernel = 10004, + PrimType_Inner_SplitReduceConcatFusion = 10005, + PrimType_Inner_EncoderLayer = 10006, + PrimType_Inner_FseDecode = 10007, + PrimType_Inner_DecoderLayer = 10008, + PrimType_Inner_UsePastEmbedding = 10009, + PrimType_Inner_CustomGru = 10010, + PrimType_Inner_CastGatherReduceFusion = 10011, + PrimType_Inner_ReduceConcatFusion = 10012, + PrimType_Inner_AclCustomOp = 10013, + PrimType_Inner_CustomMaskedFill = 10014, + PrimType_Inner_CustomTensorScatterMax = 10015, + PrimType_Inner_CustomIsInf = 10016, + PrimType_Inner_Conv3D = 10017, + PrimType_Inner_GridSampler = 10018, + PrimType_InnerOpMax, + PrimType_InnerOpMin = PrimType_Inner_ToFormat +}; + +typedef enum FormatC { + DEFAULT_FORMAT = -1, + Format_NCHW = 0, + Format_NHWC = 1, + Format_NHWC4 = 2, + Format_HWKC = 3, + Format_HWCK = 4, + Format_KCHW = 5, + Format_CKHW = 6, + Format_KHWC = 7, + Format_CHWK = 8, + Format_HW = 9, + Format_HW4 = 10, + Format_NC = 11, + Format_NC4 = 12, + Format_NC4HW4 = 13, + Format_NONE = 14, // The origin Format_NUM_OF_FORMAT can't be used. 
+ Format_NCDHW = 15, + Format_NWC = 16, + Format_NCW = 17, + Format_NDHWC = 18, + Format_NC8HW8 = 19, + Format_NC16HW16 = 20, + Format_MAX, + Format_MIN = Format_NCHW +} FormatC; + +typedef enum TypeIdC { + kTypeUnknown = 0, + kMetaTypeBegin = kTypeUnknown, + kMetaTypeType, // Type + kMetaTypeAny, + kMetaTypeObject, + kMetaTypeTypeType, // TypeType + kMetaTypeProblem, + kMetaTypeExternal, + kMetaTypeNone, + kMetaTypeNull, + kMetaTypeEllipsis, + kMetaTypeEnd, + // + // Object types + // + kObjectTypeBegin = kMetaTypeEnd, + kObjectTypeNumber, + kObjectTypeString, + kObjectTypeList, + kObjectTypeTuple, + kObjectTypeSlice, + kObjectTypeKeyword, + kObjectTypeTensorType, + kObjectTypeRowTensorType, + kObjectTypeCOOTensorType, + kObjectTypeUndeterminedType, + kObjectTypeClass, + kObjectTypeDictionary, + kObjectTypeFunction, + kObjectTypeJTagged, + kObjectTypeSymbolicKeyType, + kObjectTypeEnvType, + kObjectTypeRefKey, + kObjectTypeRef, + kObjectTypeEnd, + // + // Number Types + // + kNumberTypeBegin = kObjectTypeEnd, + kNumberTypeBool, + kNumberTypeInt, + kNumberTypeInt8, + kNumberTypeInt16, + kNumberTypeInt32, + kNumberTypeInt64, + kNumberTypeUInt, + kNumberTypeUInt8, + kNumberTypeUInt16, + kNumberTypeUInt32, + kNumberTypeUInt64, + kNumberTypeFloat, + kNumberTypeFloat16, + kNumberTypeFloat32, + kNumberTypeFloat64, + kNumberTypeDouble, + kNumberTypeComplex, + kNumberTypeComplex64, + kNumberTypeComplex128, + kNumberTypeInt4, + kNumberTypeGLUInt, + kNumberTypeEnd, +} TypeIdC; + +typedef enum DataOrder { + RowMajor, + ColMajor, +} DataOrder; + +typedef struct OpParameter { + char name_[100]; + int type_; + int thread_num_; + int quant_type_; + bool is_train_session_; + bool is_zero_shape_; + void (*destroy_func_)(struct OpParameter *param); +} OpParameter; + +typedef struct QuantArg { + float scale_; + int32_t zp_; +} QuantArg; + +typedef struct QuantMulArg { + int32_t multiplier_; + int left_shift_; + int right_shift_; +} QuantMulArg; + +typedef enum ReductionType { Reduction_Sum, Reduction_Mean, Reduction_None } ReductionType; +typedef enum ActType { + ActType_No = 0, + ActType_Relu = 1, + ActType_Sigmoid = 2, + ActType_Relu6 = 3, + ActType_Elu = 4, + ActType_LeakyRelu = 5, + ActType_Abs = 6, + ActType_Relu1 = 7, + ActType_Softsign = 8, + ActType_Softplus = 9, + ActType_Tanh = 10, + ActType_Selu = 11, + ActType_HSwish = 12, + ActType_HSigmoid = 13, + ActType_ThresholdRelu = 14, + ActType_Linear = 15, + ActType_HardTanh = 16, + ActType_Sign = 17, + ActType_Swish = 18, + ActType_Gelu = 19, + ActType_FastGelu = 20, + ActType_Unknown = 21 +} ActType; +typedef enum PadType { Pad_pad, Pad_same, Pad_valid } PadType; +typedef enum EltwiseType { Eltwise_PROD, Eltwise_SUM, Eltwise_MAXIMUM, Eltwise_UNKNOWN } EltwiseType; +typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode; + +typedef enum PaddingModeC { + PaddingMode_Constant, + PaddingMode_Reflect, + PaddingMode_Symmetric, + PaddingMode_Mode_Reserved, +} PaddingModeC; + +typedef enum ElementwiseModeC { + Elementwise_Not = 0, + Elementwise_Per_Channel = 1, + Elementwise_Per_Num = 2 +} ElementwiseModeC; + +typedef enum QuantTypeC { + Quant_None = 0, + Quant_AwareTraining = 1, + Quant_WeightQuant = 2, + Quant_PostTraining = 3, + Quant_QuantWeight = 4, + Quant_QuantAll = 5, + Quant_QuantDynamic = 6, + Quant_Min = Quant_None, + Quant_Max = Quant_QuantDynamic +} QuantTypeC; + +typedef enum TensorCategoryC { + VarTensor, // common tensor + ConstTensor, // const tensor + ConstScalar, // const scalar + GraphInput, + 
GraphOutput +} TensorCategoryC; + +typedef enum ReduceModeC { + Reduce_Mean = 0, + Reduce_Max = 1, + Reduce_Min = 2, + Reduce_Prod = 3, + Reduce_Sum = 4, + Reduce_SumSquare = 5, + Reduce_ASum = 6, + Reduce_All = 7, + Reduce_L2 = 8, + Reduce_MIN = Reduce_Mean, + Reduce_MAX = Reduce_L2 +} ReduceModeC; + +typedef enum CalFixedMultiplierMode { + Method_No, + Method_SinglePrecision, + Method_DoublePrecision +} CalFixedMultiplierMode; + +#define VA_ARG_TUPLE_LEN 2 +static inline void offset_to_index_init(int offset, int cnt, ...) { + va_list valist; + va_start(valist, cnt); + int start = offset; + for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) { + int *x = va_arg(valist, int *); + int X = va_arg(valist, int); + + *x = start % X; + start = start / X; + } + va_end(valist); +} + +static inline void offset_to_index_step(int cnt, ...) { + va_list valist; + int flag = 1; + va_start(valist, cnt); + for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) { + int *x = va_arg(valist, int *); + int X = va_arg(valist, int); + if (flag) { + *x = (++*x != X) ? (flag = 0, *x) : (flag = 1, 0); + } + } + va_end(valist); +} + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in new file mode 100644 index 0000000000000000000000000000000000000000..316a11ce233be1008b1985bccdd9fad46b27dd7d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_@OP_NAME_UPPER@_SIMD_H_ +#define NNACL_@OP_NAME_UPPER@_SIMD_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX512 +#include "nnacl_c/avx512/@OP_NAME_LOWER@_avx512.h" +#endif + +#ifdef ENABLE_AVX +#include "nnacl_c/avx/@OP_NAME_LOWER@_avx.h" +#endif + +#ifdef ENABLE_SSE +#include "nnacl_c/sse/@OP_NAME_LOWER@_sse.h" +#endif + +#ifdef ENABLE_ARM +#include "nnacl_c/neon/@OP_NAME_LOWER@_neon.h" +#endif + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b190de7416bca9aaa43ce53092d26386122da82 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt @@ -0,0 +1,62 @@ +project(optimize) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) 
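
The offset_to_index helpers in op_base.h above maintain multi-dimensional coordinates incrementally while a kernel walks a buffer in flat order; the varargs are (pointer, extent) pairs listed innermost dimension first, and cnt counts the individual varargs, not the pairs. Usage sketch:

#include <stdio.h>
#include "nnacl_c/op_base.h"

int main(void) {
  int h, w;
  offset_to_index_init(4, 4, &w, 3, &h, 2);  // flat offset 4 in [2, 3] -> h=1, w=1
  for (int offset = 4; offset < 6; offset++) {
    printf("offset=%d h=%d w=%d\n", offset, h, w);
    offset_to_index_step(4, &w, 3, &h, 2);   // carry w into h like an odometer
  }
  return 0;
}
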
+include_directories(${NNACL_DIR})
+
+########################### optimized files ###########################
+file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c ${NNACL_DIR}/kernel/f16/*.c)
+if(PLATFORM_ARM32)
+    file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S)
+else()
+    file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S)
+    file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S)
+    set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C)
+endif()
+
+set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C)
+set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C)
+
+if(APPLE)
+    set_source_files_properties(${SDOT_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp")
+    set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp")
+endif()
+########################### share library build ########################
+list(APPEND FP16_FILES ${FP16_C_SRC})
+list(APPEND FP16_FILES ${FP16_NEON_SRC})
+
+if(SUPPORT_TRAIN)
+    file(GLOB FP16_TRAIN_SRC ${NNACL_DIR}/fp16_grad/*.c)
+    list(APPEND FP16_FILES ${FP16_TRAIN_SRC})
+endif()
+if(NOT MSVC)
+    string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
+endif()
+if(MACHINE_LINUX_ARM64)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+fp16")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+fp16")
+elseif(NOT PLATFORM_ARM32 AND NOT TARGET_HIMIX AND (NOT (TARGET_AOS_ARM AND TOOLCHAIN_NAME STREQUAL "gcc")))
+    list(APPEND SDOT_FILES ${SDOT_SRC})
+    add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
+    add_dependencies(nnacl_optimize_mid fbs_src)
+    if(NOT TARGET_MIX210)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
+    endif()
+endif()
+
+if(MSLITE_ENABLE_FP16)
+    add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
+    add_dependencies(nnacl_fp16_mid fbs_src)
+    if(PLATFORM_ARM32)
+        target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
+    endif()
+    if(TARGET_AOS_ARM)
+        if(TOOLCHAIN_NAME STREQUAL "gcc")
+            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+simd+fp16 -mtune=cortex-a72")
+            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+simd+fp16 -mtune=cortex-a72")
+        else()
+            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+simd+dotprod+fp16 -mtune=cortex-a72")
+            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+simd+dotprod+fp16 -mtune=cortex-a72")
+        endif()
+    endif()
+endif()
\ No newline at end of file
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h
new file mode 100644
index 0000000000000000000000000000000000000000..701c3cdafde6bb3b6dd885843a1b6e72975db8b2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_PACK_H_
+#define NNACL_PACK_H_
+
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/int8/pack_int8.h"
+
+#endif  // NNACL_PACK_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..386e5c42a5365dc761e9a44b3d96e5d8097a1d5d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_PAD_PARAMETER_H_
+#define NNACL_PAD_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+#define MAX_PAD_SIZE 12
+#define DEFAULT_PAD_NDIMS 6
+
+typedef struct PadQuantArg {
+  QuantArg *in_quant_args_;
+  QuantArg *out_quanr_args_;
+  int8_t *constant_value_;
+} PadQuantArg;
+
+typedef struct PadParameter {
+  OpParameter op_parameter_;
+  int paddings_[MAX_PAD_SIZE];
+  int pad_mode_;
+  float constant_value_;
+  int padding_length;
+} PadParameter;
+
+#endif  // NNACL_PAD_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..85c2e69b869c0486cb258cd317129f8d91d3b17c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_PARTIAL_FUSION_H_
+#define NNACL_PARTIAL_FUSION_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/nnacl_utils.h"
+
+typedef struct PartialParameter {
+  OpParameter op_parameter_;
+  int sub_graph_index_;
+} PartialParameter;
+
+#endif  // NNACL_PARTIAL_FUSION_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2d962c2213851a38f92bc07015774b04a164e1d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_POOLING_PARAMETER_H_ +#define NNACL_POOLING_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode; + +typedef enum RoundType { RoundType_No, RoundType_Ceil, RoundType_Floor } RoundType; + +typedef struct PoolingParameter { + OpParameter op_parameter_; + PoolMode pool_mode_; + RoundType round_type_; + PadType pad_mode_; + ActType act_type_; + int avg_mode_; + bool global_; + int window_w_; + int window_h_; + int stride_w_; + int stride_h_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; +} PoolingParameter; + +typedef struct Pooling3DParameter { + PoolingParameter pooling_parameter_; + int window_d_; + int stride_d_; + int input_d_; + int output_d_; + int pad_f_; // front + int pad_b_; // back + bool count_include_pad_; + int divisor_override_; +} Pooling3DParameter; + +#endif // NNACL_POOLING_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d62371a00b10ddd6d13604947724a8ac9ab12a59 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_POW_PARAMETER_H_ +#define NNACL_POW_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct PowQuantArg { + QuantArg in_args_; + QuantArg exp_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} PowQuantArg; + +typedef struct PowParameter { + OpParameter op_parameter_; + float power_; + float scale_; + float shift_; +} PowParameter; + +#endif // NNACL_POW_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..cf86950ea29aa88c4ec87f32f0d2a9c2dc18ce8e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_PREDICT_PARAMETER_H_ +#define NNACL_PREDICT_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct { + // Primitive parameter + OpParameter op_parameter_; + // other parameter + int output_num; + float weight_threshold; +} PredictParameter; + +typedef struct { + int label; + float weight; +} LabelInfo; +#endif // NNACL_PREDICT_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..c9c3b1d27a5a6668a10eb215a747e452126367ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_PRELU_PARAMETER_H_ +#define NNACL_PRELU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct PReluParameter { + OpParameter op_parameter_; + bool channel_shared_; +} PReluParameter; + +#endif // NNACL_PRELU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ad83559915cb342d876870829405b9606f3f34e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_PRIOR_BOX_PARAMETER_H_
+#define NNACL_PRIOR_BOX_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct PriorBoxParameter {
+  OpParameter op_parameter_;
+  int32_t min_sizes_size;
+  int32_t min_sizes[MAX_SHAPE_SIZE];
+  int32_t max_sizes_size;
+  int32_t max_sizes[MAX_SHAPE_SIZE];
+  int32_t aspect_ratios_size;
+  float aspect_ratios[MAX_SHAPE_SIZE];
+  float variances[COMM_SHAPE_SIZE];
+  int32_t image_size_w;
+  int32_t image_size_h;
+  float step_w;
+  float step_h;
+  bool clip;
+  bool flip;
+  float offset;
+} PriorBoxParameter;
+
+#endif  // NNACL_PRIOR_BOX_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..3efeaf945026f813a47f9c9d2ad9ad79b94fa1fb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_RANDOM_PARAMETER_H_
+#define NNACL_RANDOM_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct RandomParam {
+  OpParameter op_parameter_;
+  int seed_;
+  int seed2_;
+} RandomParam;
+
+typedef struct RandomNormalParam {
+  OpParameter op_parameter_;
+  float seed_;
+  float mean_;
+  float scale_;
+} RandomNormalParam;
+
+#endif  // NNACL_RANDOM_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..64166deca1c5b2ce7e8bed51f8f9656a4499e1ba
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_RANGE_PARAMETER_H_ +#define NNACL_RANGE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct RangeParameter { + OpParameter op_parameter_; + int dtype_; + int start_; + int limit_; + int delta_; +} RangeParameter; + +#endif // NNACL_RANGE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..40df590efa5eadc557de5b752e0747607e1b2ef8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REDUCE_PARAMETER_H_ +#define NNACL_REDUCE_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct ReduceParameter { + OpParameter op_parameter_; + bool keep_dims_; + int mode_; + bool reduce_to_end_; + float coeff; +} ReduceParameter; + +#endif // NNACL_REDUCE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..4ffe1b551876a40e5e53a1c6874ea94f3331cb83 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REDUCE_SCATTER_PARAMETER_H_ +#define NNACL_REDUCE_SCATTER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ReduceScatterParameter { + // primitive parameter + OpParameter op_parameter_; + char group_[DEFAULT_GROUP_NAME_LEN]; + int mode_; + + // other parameter + int rank_size_; +} ReduceScatterParameter; +#endif // NNACL_REDUCE_SCATTER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..fd8bfa4c9a2fb6a9c427247e9a680a63f3d39eb7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_RESHAPE_PARAMETER_H_
+#define NNACL_RESHAPE_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct ReshapeQuantArg {
+  QuantArg in_args_;
+  QuantArg out_args_;
+  int output_activation_min_;
+  int output_activation_max_;
+} ReshapeQuantArg;
+
+typedef struct ReshapeParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int shape_dim_;
+  int shape_[MAX_SHAPE_SIZE];
+
+  // other parameter
+  ReshapeQuantArg quant_para_;
+  int thread_count_;
+} ReshapeParameter;
+
+#endif  // NNACL_RESHAPE_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..950821b6cd1c9b015e483a23a2047d6662ffbc18
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_RESIZE_PARAMETER_H_
+#define NNACL_RESIZE_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+typedef struct ResizeParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int method_;
+  int64_t new_height_;
+  int64_t new_width_;
+  int coordinate_transform_mode_;
+  float cubic_coeff_;
+  bool preserve_aspect_ratio_;
+} ResizeParameter;
+
+typedef struct CropAndResizeParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int method_;
+  float extrapolation_value_;
+} CropAndResizeParameter;
+#endif  // NNACL_RESIZE_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h
new file mode 100644
index 0000000000000000000000000000000000000000..dc7c02a0bb06e05d111ac63e69f26d6e676b720c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef NNACL_REVERSE_PARAMETER_H_ +#define NNACL_REVERSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +#define REVERSE_SHAPE_MAX_SIZE 4 + +typedef struct ReverseParameter { + OpParameter op_parameter_; + int axis_[REVERSE_SHAPE_MAX_SIZE]; + int num_axis_; +} ReverseParameter; + +#endif // NNACL_REVERSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..11e299063eb53ec316c8dad247cf79ad9623f9a3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REVERSE_SEQUENCE_PARAMETER_H_ +#define NNACL_REVERSE_SEQUENCE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ReverseSequenceParameter { + // primitive parameter + OpParameter op_parameter_; + int seq_axis_; + int batch_axis_; + + // shape correlative + int input_shape0_[5]; + int output_shape_[5]; + int input_stride_[5]; + int output_stride_[5]; + + // other parameter + int ndim_; + int outer_count_; + int outer_stride_; + int inner_count_; + int inner_stride_; + int copy_byte_size_; + int total_data_size_; + bool is_seq_length_int32_; +} ReverseSequenceParameter; + +#endif // NNACL_REVERSE_SEQUENCE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..65ad061c5bdc7d7a81a1d3e95433e05a74176bc3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SCALE_H_ +#define NNACL_SCALE_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ScaleParameter { + OpParameter op_parameter_; + int axis_; + int activation_type_; +} ScaleParameter; + +typedef struct ScaleQuantParameter { + QuantMulArg scale_mul_arg_; + QuantMulArg offset_mul_arg_; + int input_zp_; + int scale_zp_; + int offset_zp_; + int output_zp_; + int output_activation_min_; + int output_activation_max_; +} ScaleQuantParameter; + +#endif // NNACL_SCALE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b1547dac1ac5849b7adcc89a5baf007f2f9db6aa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_SCATTER_ELEMENTS_PARAMETER_H_ +#define NNACL_SCATTER_ELEMENTS_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct ScatterElementsParameter { + OpParameter op_parameter_; + int axis_; +} ScatterElementsParameter; + +#endif // NNACL_SCATTER_ELEMENTS_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..6a70ad86b6bdb43a2aa25645bb1a4082c3689426 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SCATTER_ND_PARAMETER_H_ +#define NNACL_SCATTER_ND_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ScatterNDParameter { + OpParameter op_parameter; + int num_unit; + int unit_size; + int data_type_len; +} ScatterNDParameter; + +#endif // NNACL_SCATTER_ND_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..14ac9d56d02b8659a46b82732d4d2fd50b135d54 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ +#define MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SequenceUnstackParameter { + // primitive parameter + OpParameter op_parameter_; + int num_; + int axis_; + + // other parameter + int pre_dims_; + int axis_dim_; + int after_dims_; +} SequenceUnstackParameter; + +#endif // MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d864ef5d1bab4ec6821b48f7b2ae95b6cc33709f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SIGMOID_PARAMETER_H_ +#define NNACL_SIGMOID_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SigmoidParameter { + // primitive parameter + OpParameter op_parameter_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + + // other parameter + SigmoidQuantArg quant_arg; + double alpha_; + int thread_count_; + int64_t offset_[MAX_SHAPE_SIZE]; + int64_t in_offset_[MAX_SHAPE_SIZE]; + int64_t axis_; + int input_dim_; + int element_num; +} SigmoidParameter; + +#endif // NNACL_SIGMOID_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..46bb751f15c5144445f5a9df58aa9ae3022c648d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SKIP_GRAM_PARAMETER_H_ +#define NNACL_SKIP_GRAM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SkipGramParameter { + // primitive parameter + OpParameter op_parameter_; + bool include_all_ngrams; + int max_skip_size; + int ngram_size; +} SkipGramParameter; + +#endif // NNACL_SKIP_GRAM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ca5af2fa5da8c40e46cf7fa4c6410d6310fc59c4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SLICE_PARAMETER_H_ +#define NNACL_SLICE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SliceQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + QuantMulArg multiplier_; +} SliceQuantArg; + +typedef struct SliceParameter { + OpParameter op_parameter_; + int32_t axis_[DIMENSION_8D]; +} SliceParameter; + +#endif // NNACL_SLICE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..428e06735875298c290996079d7545f6293c9b6f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SOFTMAX_PARAMETER_H_ +#define NNACL_SOFTMAX_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SoftmaxParameter { + OpParameter op_parameter_; + int32_t axis_; +} SoftmaxParameter; + +#endif // NNACL_SOFTMAX_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..fa337a8c8153703323b753120dab45ec328c1d6c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ +#define LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct SpaceToDepthParameter { + // primitive parameter + OpParameter op_parameter_; + int32_t block_size_; + int32_t date_type_len; +} SpaceToDepthParameter; + +#endif // LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..2aa82ea0826ed25fbc934856818c819ef7a01dfe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPARSE_TO_DENSE_PARAMETER_H_ +#define NNACL_SPARSE_TO_DENSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SparseToDenseParameter { + // primitive parameter + OpParameter op_parameter_; + bool validate_indices_; + bool is_scalar; + int index_num; + int output_num; + int output_stride[DIMENSION_4D]; +} SparseToDenseParameter; + +#endif // NNACL_SPARSE_TO_DENSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..1bb1b7b7ea53836a3ec7724eb55a5976e2fe47b5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SPLICE_PARAMETER_H_ +#define NNACL_SPLICE_PARAMETER_H_ +#include "nnacl_c/op_base.h" +typedef struct SpliceParameter { + OpParameter op_parameter_; + int context_dim_; + int forward_indexes_dim_; + int *context_; + int *forward_indexes_; + int output_dim_; +} SpliceParameter; +#endif // NNACL_SPLICE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..5365daac43598bfb1a2577d46c04e863da485798 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPLIT_PARAMETER_H_ +#define NNACL_SPLIT_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +#define SPLIT_STRIDES_SIZE 32 +#define SPLIT_MAX_SLICE_NUM 10 + +typedef struct SplitQuantArg { + QuantArg in_args_; + QuantArg out_args_[20]; + int output_activation_min_; + int output_activation_max_; +} SplitQuantArg; + +typedef struct SplitParameter { + // primitive parameter + OpParameter op_parameter_; + int num_split_; + int *split_sizes_; + int split_dim_; + + // shape correlative + int strides_[SPLIT_STRIDES_SIZE]; + + // other parameter + SplitQuantArg quant_arg_; + int n_dims_; + int split_count_; +} SplitParameter; + +typedef struct SplitWithOverlapParameter { + OpParameter op_parameter_; + int num_split_; + int split_dim_; + int ratio_[SPLIT_MAX_SLICE_NUM]; + int extend_top_[SPLIT_MAX_SLICE_NUM]; + int extend_bottom_[SPLIT_MAX_SLICE_NUM]; + + // other parameter + int element_bytes_; + int split_dim_size_; + int outer_total_dim_; + int inner_stride_; +} SplitWithOverlapParameter; + +#endif // NNACL_SPLIT_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..cfbaad63f2075647fec5fa757c1e675508867015 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SQUEEZE_PARAMETER_H_ +#define NNACL_SQUEEZE_PARAMETER_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#define SQUEEZE_OFFSET_MAX_SIZE 4 + +typedef struct SqueezeQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quant_args_; +} SqueezeQuantArg; + +typedef struct SqueezeParameter { + // primitive parameter + OpParameter op_parameter_; + int axis_[8]; + size_t axis_size_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + int offset_size_; + int64_t offset_[SQUEEZE_OFFSET_MAX_SIZE]; + int64_t in_offset_[SQUEEZE_OFFSET_MAX_SIZE]; + int input_dim_; + // other parameter + SqueezeQuantArg quant_arg; +} SqueezeParameter; + +#endif // NNACL_SQUEEZE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..55d66a514fc6cec213e9c2de4af99f5153395be8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_STACK_PARAMETER_H_ +#define NNACL_STACK_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct StackParameter { + // primitive parameter + OpParameter op_parameter_; + int32_t axis_; +} StackParameter; + +#endif // NNACL_STACK_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..3ff8618b338f7ad0896ddf2d0b3bd2f76121db5f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_STRIDED_SLICE_PARAMETER_H_ +#define NNACL_STRIDED_SLICE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct StridedSliceParameter { + // primitive parameter + OpParameter op_parameter_; + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int isScale; + + // shape correlative + int in_shape_length_; + int in_shape_[MAX_SHAPE_SIZE]; + + // other parameter + int num_axes_; + TypeIdC data_type; + int begins_mask_; + int ends_mask_; + int ellipsisMask_; + int newAxisMask_; + int shrinkAxisMask_; +} StridedSliceParameter; + +#endif // NNACL_STRIDED_SLICE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..aca72845c6f510f79610fedab45bd3ac01d5e340 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TENSOR_ARRAY_PARAMETER_H_ +#define NNACL_TENSOR_ARRAY_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct TensorArrayParameter { + OpParameter op_parameter_; + bool dynamic_size_; + bool identical_element_shapes_; + int element_shape_[MAX_SHAPE_SIZE]; + int element_shape_size_; + int data_type_; +} TensorArrayParameter; + +#endif // NNACL_TENSOR_ARRAY_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h new file mode 100644 index 0000000000000000000000000000000000000000..6d515a2839268d4aca5cac5772fb5763671e48a6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TENSOR_C_H_ +#define NNACL_TENSOR_C_H_ +#include "nnacl_c/op_base.h" + +typedef struct TensorC { + bool shape_changed_; + int data_type_; + int format_; + int category_; + void *data_; + size_t shape_size_; + int shape_[MAX_SHAPE_SIZE]; + char *name_; // only used in micro now. 
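+  // category_ stores a TensorCategoryC value (e.g. ConstTensor, ConstScalar,
+  // GraphOutput; see op_base.h); NNACLIsConst() in tensor_c_utils.c checks it
+  // together with data_ to decide whether a tensor is constant.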
+} TensorC;
+
+#endif  // NNACL_TENSOR_C_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c
new file mode 100644
index 0000000000000000000000000000000000000000..2a2564e8559190cf1c7bb8c1157861a1893143c0
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c
@@ -0,0 +1,439 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/nnacl_common.h"
+
+int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     const OpParameter *parameter) {
+  NNACL_CHECK_NULL_RETURN_ERR(inputs);
+  NNACL_CHECK_NULL_RETURN_ERR(outputs);
+  for (size_t i = 0; i < inputs_size; i++) {
+    if (inputs[i] == NULL) {
+      return NNACL_NULL_PTR;
+    }
+  }
+  for (size_t i = 0; i < outputs_size; i++) {
+    if (outputs[i] == NULL) {
+      return NNACL_NULL_PTR;
+    }
+  }
+  if (parameter == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  return NNACL_OK;
+}
+
+int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                         const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) {
+  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (check_ret == NNACL_NULL_PTR) {
+    return NNACL_NULL_PTR;
+  }
+  if (inputs_size != inputs_size_obj || outputs_size != outputs_size_obj) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  return NNACL_OK;
+}
+
+int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                 size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0,
+                                 size_t inputs_size_obj_1, size_t outputs_size_obj) {
+  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (check_ret == NNACL_NULL_PTR) {
+    return NNACL_NULL_PTR;
+  }
+  if ((inputs_size != inputs_size_obj_0 && inputs_size != inputs_size_obj_1) || outputs_size != outputs_size_obj) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  return NNACL_OK;
+}
+
+int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                              const OpParameter *parameter, size_t inputs_size_obj) {
+  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (check_ret == NNACL_NULL_PTR) {
+    return NNACL_NULL_PTR;
+  }
+  if (inputs_size != inputs_size_obj) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  return NNACL_OK;
+}
+
+int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                               size_t outputs_size, const OpParameter *parameter, size_t outputs_size_obj) {
+  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (check_ret == NNACL_NULL_PTR) {
+    return NNACL_NULL_PTR;
+  }
+  if (outputs_size != outputs_size_obj) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  return
NNACL_OK; +} + +int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size < inputs_size_obj || outputs_size < outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +void SetShapeTensor(TensorC *dst, const TensorC *src) { + for (size_t i = 0; i < src->shape_size_; i++) { + dst->shape_[i] = src->shape_[i]; + } + dst->shape_size_ = src->shape_size_; +} + +void SetShapeArray(TensorC *dst, const int *src, size_t src_size) { + for (size_t i = 0; i < src_size && i < MAX_SHAPE_SIZE; i++) { + dst->shape_[i] = src[i]; + } + dst->shape_size_ = src_size; +} + +void SetDataTypeFormat(TensorC *dst, const TensorC *src) { + dst->format_ = src->format_; + dst->data_type_ = src->data_type_; +} + +int NNACLGetBatch(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NHWC: + case Format_NHWC4: + case Format_NCHW: + case Format_NC4HW4: + case Format_NC8HW8: + case Format_KCHW: + case Format_KHWC: + case Format_NC: + case Format_NC4: + return tensor->shape_[kNHWC_N]; + case Format_HWCK: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWCN_N]; + case Format_HWKC: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWNC_N]; + case Format_CKHW: + return tensor->shape_[1]; + default: + return -1; + } +} +int NNACLGetHeight(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNCHW_H]; + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + case Format_CHWK: + return tensor->shape_[kNHWC_H]; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + return tensor->shape_[0]; + default: + return -1; + } +} +int NNACLGetWidth(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNCHW_W]; + case Format_KHWC: + case Format_NHWC: + case Format_NHWC4: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNHWC_W]; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + return tensor->shape_[1]; + default: + return -1; + } +} +int NNACLGetChannel(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_NC: + case Format_NC4: + case Format_NC4HW4: + case Format_NC8HW8: + return tensor->shape_[kNCHW_C]; + case Format_HWCK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWCN_C]; + case Format_HWKC: + case Format_NHWC: 
+ case Format_NHWC4: + case Format_KHWC: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNHWC_C]; + case Format_CKHW: + case Format_CHWK: + return tensor->shape_[0]; + default: + return -1; + } +} + +void NNACLSetBatch(TensorC *tensor, int batch) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NHWC: + case Format_NHWC4: + case Format_NCHW: + case Format_NC4HW4: + case Format_NC8HW8: + case Format_KCHW: + case Format_KHWC: + case Format_NC: + case Format_NC4: + tensor->shape_[kNHWC_N] = batch; + return; + case Format_HWCK: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWCN_N] = batch; + return; + case Format_HWKC: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWNC_N] = batch; + return; + case Format_CKHW: + tensor->shape_[1] = batch; + return; + default: + return; + } +} + +void NNACLSetHeight(TensorC *tensor, int height) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNCHW_H] = height; + return; + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + case Format_CHWK: + tensor->shape_[kNHWC_H] = height; + return; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + tensor->shape_[0] = height; + return; + default: + return; + } +} + +void NNACLSetWidth(TensorC *tensor, int width) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNCHW_W] = width; + return; + case Format_KHWC: + case Format_NHWC: + case Format_NHWC4: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNHWC_W] = width; + return; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + tensor->shape_[1] = width; + return; + default: + return; + } +} + +void NNACLSetChannel(TensorC *tensor, int channel) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_NC: + case Format_NC4: + case Format_NC4HW4: + case Format_NC8HW8: + tensor->shape_[kNCHW_C] = channel; + return; + case Format_HWCK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWCN_C] = channel; + return; + case Format_HWKC: + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNHWC_C] = channel; + return; + case Format_CKHW: + case Format_CHWK: + tensor->shape_[0] = channel; + return; + default: + return; + } +} + +int NNACLGetSize(const TensorC *tensor) { + int element_num = NNACLGetElementNum(tensor); + int data_type_size = (int)DataTypeCSize(tensor->data_type_); + return element_num * data_type_size; +} + +int NNACLGetElementNum(const TensorC *tensor) { + if (tensor == NULL) { + return -1; + } + if (tensor->shape_size_ == 0) { + return 1; // scalar mode + } + int res = 1; + for (size_t i = 0; i < 
tensor->shape_size_; i++) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(res, tensor->shape_[i], NNACL_ERRCODE_MUL_OVERFLOW);
+    res = res * tensor->shape_[i];
+  }
+
+  int c = NNACLGetChannel(tensor);
+  if (c == 0) {
+    return res;
+  }
+  if (tensor->format_ == Format_NC4HW4) {
+    res = res / c * UP_ROUND(c, C4NUM);
+  }
+  if (tensor->format_ == Format_NC8HW8) {
+    res = res / c * UP_ROUND(c, C8NUM);
+  }
+  return res;
+}
+
+int NNACLGetDimensionSize(const TensorC *tensor, const size_t index) {
+  int dim_size = -1;
+  if (index < tensor->shape_size_) {
+    dim_size = tensor->shape_[index];
+  }
+  return dim_size;
+}
+
+bool NNACLIsShapeSame(const TensorC *tensor1, const TensorC *tensor2) {
+  if (tensor1->shape_size_ != tensor2->shape_size_) {
+    return false;
+  }
+  for (size_t i = 0; i < tensor1->shape_size_; i++) {
+    if (tensor1->shape_[i] != tensor2->shape_[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool NNACLIsConst(const TensorC *tensor) {
+  return (tensor->category_ == ConstTensor || tensor->category_ == ConstScalar) && tensor->data_ != NULL;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..29a8112ebd6f3131bc149a566b38eab97f0c8052
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_TENSORC_UTILS_H_
+#define NNACL_TENSORC_UTILS_H_
+
+#include <stdbool.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int NNACLGetBatch(const TensorC *tensor);
+int NNACLGetHeight(const TensorC *tensor);
+int NNACLGetWidth(const TensorC *tensor);
+int NNACLGetChannel(const TensorC *tensor);
+void NNACLSetBatch(TensorC *tensor, int batch);
+void NNACLSetHeight(TensorC *tensor, int height);
+void NNACLSetWidth(TensorC *tensor, int width);
+void NNACLSetChannel(TensorC *tensor, int channel);
+int NNACLGetElementNum(const TensorC *tensor);
+int NNACLGetSize(const TensorC *tensor);
+int NNACLGetDimensionSize(const TensorC *tensor, const size_t index);
+bool NNACLIsShapeSame(const TensorC *tensor1, const TensorC *tensor2);
+bool NNACLIsConst(const TensorC *tensor);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_TENSORC_UTILS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h
new file mode 100644
index 0000000000000000000000000000000000000000..d0839275be1b86c218dbdc863333162757006adc
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORLIST_C_H_ +#define NNACL_TENSORLIST_C_H_ + +#include "nnacl_c/tensor_c.h" + +typedef struct vvector { + int **shape_; // value of shapes + int *shape_size_; // size of shape + size_t size_; // number of shapes +} vvector; + +typedef struct TensorListC { + bool shape_changed_; + int data_type_; + int format_; + int shape_value_; + int tensors_data_type_; // element_data_type_, keep same as C++ + int max_elements_num_; + TensorC **tensors_; + size_t element_num_; + size_t element_shape_size_; + int element_shape_[MAX_SHAPE_SIZE]; +} TensorListC; + +#endif // NNACL_TENSORLIST_C_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..4a6122c3e9f02c25988d8ee398ad822a87764aae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c @@ -0,0 +1,82 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/tensorlist_c_utils.h" + +int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape) { + // This function creates a new tensors_ array. + // You must set the shape (param2: tensor_shape) and data_type_ (tensors_data_type_ = param1: dtype) of each tensor in + // tensors_. After that, call MallocData to allocate the data buffer of each tensor in tensors_.
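+ //
+ // A usage sketch (illustrative names and shapes, not part of this patch): build a vvector
+ // describing two {2, 3} elements, size element_num_ to match, then call this function:
+ //   int s0[] = {2, 3}, s1[] = {2, 3};
+ //   int *shapes[] = {s0, s1};
+ //   int shape_sizes[] = {2, 2};
+ //   vvector shape = {shapes, shape_sizes, 2};
+ //   list.element_num_ = 2;  // must equal shape.size_, otherwise NNACL_ERR is returned
+ //   int ret = MallocTensorListData(&list, kNumberTypeFloat32, &shape);  // NNACL_OK on success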
+ + if (tensor_list->element_num_ == 0) { + return NNACL_OK; + } + if (((size_t)(tensor_list->element_num_)) != tensor_shape->size_) { + return NNACL_ERR; + } + tensor_list->tensors_data_type_ = dtype; + void *addr = malloc(tensor_list->element_num_ * sizeof(void *) + + tensor_list->element_num_ * sizeof(TensorC)); // free in infer_manager + if (addr == NULL) { + free(tensor_list->tensors_); + return NNACL_NULL_PTR; + } + memset(addr, 0, tensor_list->element_num_ * sizeof(void *) + tensor_list->element_num_ * sizeof(TensorC)); + tensor_list->tensors_ = (TensorC **)addr; + TensorC *tensors = (TensorC *)(tensor_list->tensors_ + tensor_list->element_num_); + for (size_t i = 0; i < tensor_list->element_num_; ++i) { + TensorC *tensor = tensors + i; + tensor_list->tensors_[i] = tensor; + tensor->format_ = Format_NHWC; + tensor->data_type_ = dtype; + ShapeSet(tensor->shape_, &(tensor->shape_size_), tensor_shape->shape_[i], (size_t)tensor_shape->shape_size_[i]); + } + return NNACL_OK; +} + +int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size) { + if (*element_shape_size >= 255 || element_shape[0] == -1) { + ShapeSet(element_shape, element_shape_size, tmp, tmp_size); + return NNACL_OK; + } + if (*element_shape_size != tmp_size) { + return NNACL_ERR; + } + for (size_t j = 0; j < tmp_size; ++j) { + if (element_shape[j] >= 0 && tmp[j] >= 0 && element_shape[j] != tmp[j]) { + return NNACL_ERR; + } + element_shape[j] = element_shape[j] >= 0 ? element_shape[j] : tmp[j]; + } + return NNACL_OK; +} + +bool TensorListIsFullyDefined(const int *shape, size_t shape_size) { + for (size_t i = 0; i < shape_size; ++i) { + if (shape[i] < 0) { + return false; + } + } + return true; +} + +bool InferFlagTensorList(TensorC *tensorc) { + TensorListC *input_tensor_list = (TensorListC *)tensorc; + if (input_tensor_list->shape_value_ == -1) { + return false; + } + return true; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a69ecc126f90a76f6946766f50eaefeee95ed62a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h @@ -0,0 +1,38 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_TENSORLIST_C_UTILS_H_ +#define NNACL_TENSORLIST_C_UTILS_H_ + +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensorlist_c.h" +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape); +int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size); +bool TensorListIsFullyDefined(const int *shape, size_t shape_size); +bool InferFlagTensorList(TensorC *tensor_list); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_TENSORLIST_C_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..beb70c1be7aa7d26653e441bbee7fe6bbcc85898 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORLIST_PARAMETER_H_ +#define NNACL_TENSORLIST_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TensorListParameter { + // primitive parameter + OpParameter op_parameter_; + int shape_type_; + int element_dtype_; + + // other parameter + int num_element_; +} TensorListParameter; + +#endif // NNACL_TENSORLIST_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d7ad9ef93b8039ecfe7238bc60f81636733fd59d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef NNACL_TILE_PARAMETER_H_ +#define NNACL_TILE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TileParameter { + OpParameter op_parameter_; + size_t dims_size_; + int dims_[MAX_SHAPE_SIZE]; + int multiples_[MAX_SHAPE_SIZE]; +} TileParameter; + +#endif // NNACL_TILE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..34a7845345d8f41be5f4904fff1224e246f030ba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TRANSPOSE_PARAMETER_H_ +#define NNACL_TRANSPOSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +// MAX_TRANSPOSE_SERIAL_SIZE = 64 * 3 * 512 * 512 +#define MAX_TRANSPOSE_SERIAL_SIZE 50331648 +#define MAX_TRANSPOSE_DIM_SIZE 20 +#define PERM_NUM_THREE 3 +#define PERM_NUM_FOUR 4 + +typedef struct TransposeParameter { + // primitive parameter + OpParameter op_parameter_; + int perm_[MAX_TRANSPOSE_DIM_SIZE]; + size_t perm_size_; + bool conjugate_; + + // shape correlative + int strides_[MAX_TRANSPOSE_DIM_SIZE]; + int out_strides_[MAX_TRANSPOSE_DIM_SIZE]; + + // other parameter + int num_axes_; + int data_num_; +} TransposeParameter; + +#endif // NNACL_TRANSPOSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b9b5a95e942c07d580da58ea07b03bc29352ed2c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_TRIU_TRIL_PARAMETER_H_ +#define NNACL_TRIU_TRIL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TriuParameter { + // Primitive parameter + OpParameter op_parameter_; +} TriuParameter; + +typedef struct TrilParameter { + // Primitive parameter + OpParameter op_parameter_; +} TrilParameter; + +#endif // NNACL_TRIU_TRIL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e7f6d643f4691f163560b6a1fd5bb471ad840290 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UNSQUEEZE_PARAMETER_H_ +#define NNACL_UNSQUEEZE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct UnSqueezeQuantArg { + int *output_shape_; + float alpha; + int axis_; + size_t input_num_; + QuantArg in_quant_args_; + QuantArg out_quant_args_; +} UnSqueezeQuantArg; + +typedef struct UnSqueezeParameter { + // primitive parameter + OpParameter op_parameter_; + int dims_[COMM_SHAPE_SIZE]; + int num_dim_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + int64_t offset_[COMM_SHAPE_SIZE]; + int64_t axis_; + + // other parameter + UnSqueezeQuantArg quant_arg; + int thread_count_; +} UnSqueezeParameter; + +#endif // NNACL_UNSQUEEZE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0e328190616f94a6ec2b9c2c3a7542e3d18ad162 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_UNSTACK_PARAMETER_H_ +#define NNACL_UNSTACK_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct UnstackParameter { + // primitive parameter + OpParameter op_parameter_; + int num_; + int axis_; + + // other parameter + int pre_dims_; + int axis_dim_; + int after_dims_; +} UnstackParameter; + +#endif // NNACL_UNSTACK_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..79b931b611d6a5a6cb24b77dba5e27db16ddb010 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UPSAMPLE_PARAMETER_H_ +#define NNACL_UPSAMPLE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct { + // primitive parameter + OpParameter op_parameter_; + + // other parameter + int method_; // 0 for bilinear; 1 for nearest +} UpsampleParameter; + +#endif // NNACL_UPSAMPLE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..973569facf296e69ddfe6889910b7de6476af04a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_WHERE_PARAMETER_H_ +#define NNACL_WHERE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct WhereParameter { + OpParameter op_parameter_; +} WhereParameter; + +#endif // NNACL_WHERE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h b/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h index 30d190d57866ac64f9375108182602a0bc052070..7f55b28191ee09cfbe6aa9452e811c897f8f441d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h +++ b/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h @@ -19,7 +19,7 @@ #include -#include "nnacl/lsh_projection_parameter.h" +#include "nnacl_c/lsh_projection_parameter.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/string/predict.h b/mindspore-lite/src/litert/kernel/cpu/string/predict.h index 768d1d4c013fc67088bfab4d04e05d28b103a330..9f5f6fa4e94ef5543827fcf9b879854e5c040faa 100644 --- a/mindspore-lite/src/litert/kernel/cpu/string/predict.h +++ b/mindspore-lite/src/litert/kernel/cpu/string/predict.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/predict_parameter.h" +#include "nnacl_c/predict_parameter.h" namespace mindspore::kernel { class PredictCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h b/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h index f5e7a29b251c1c58126d84f75e736b7f56cddb54..9b32619bc1dfb1f4beead8007c7948e65e39c8b1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h +++ b/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" #include "src/common/string_utils.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc b/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc index 6a5780812fa92dc9fee95c48278e09c4eacd8a43..fdc3a87cb9b2cc2616e64bf6195f5d65754ecea5 100644 --- a/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc +++ b/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/gpu/opencl/opencl_executor.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" namespace mindspore::lite::opencl { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h b/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h index 11bcbb7894c62157684d87b00194259a7230f9dc..dcec9c9b3b182e17748d4f2c617933f33d385dbd 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h @@ -21,7 +21,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" namespace mindspore::kernel { class ActivationOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h b/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h index 6277b480f5204a35050e738aa39545d06331c4c0..0fdfe87ca678b4d8b3a40f29bf7548b1ac9f8f5f 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h @@ -19,8 +19,8 @@ #include #include 
"src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/arg_min_max_parameter.h" -#include "nnacl/kernel/arg_min_max.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/kernel/arg_min_max.h" namespace mindspore::kernel { class ArgMinMaxOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc index 603ae5595c59046eb7a4a6ac5331fdd66859d744..5ffc86b040e1a1f60567040c6a54e358a9bc5c73 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc @@ -21,7 +21,7 @@ #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/arithmetic.cl.inc" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h index d819fb29e324a8b7883cb2e33eccbb28c954edec..68c284682c944c551e8d5b7dd3978e75fc7c0701 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/arithmetic_self_parameter.h" using mindspore::schema::PrimitiveType_Abs; using mindspore::schema::PrimitiveType_Ceil; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h b/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h index 1d5fa37090873fdf3372c99e3591ec99f68b15b1..a8e6049fefdfa5038d917058e61d2d0bcee1bef9 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" namespace mindspore::kernel { class BatchToSpaceNDOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc index c6019a4328eb2984c7083fa653ef4e809444a925..6f240aebfa692116ac1d4d3d1ed69dc36879a1e9 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc @@ -21,7 +21,7 @@ #include "src/litert/kernel/opencl/kernel/batchnorm.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/batchnorm.cl.inc" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h index 44563b8d8d54bba6da6b1e7b6089132b6ea22104..94f384046b0a95fc57969b663ea8910cd22c9fc6 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" namespace mindspore::kernel { class 
BatchNormOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h b/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h index 3876a745928dd4d94a45fc197e62085e39ed2561..15094387326ba3919f9095c46d1444e33e267a7c 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" namespace mindspore { namespace kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h index f10a002661862c5631b45805968dae16be278584..4ad08678e501dde335f715cb8ad08d4a4e48eddf 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h @@ -22,7 +22,7 @@ #include "src/tensor.h" #include "src/litert/kernel/opencl/opencl_kernel.h" #include "schema/model_generated.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "schema/ops_generated.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc index d38d1eff40622e819f888206e2ef579bbb95fb8d..e8efae25a3fa348c7ce3642b68e3c0109e80b70a 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/opencl/kernel/conv2d_transpose.h" #include #include -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/cl/conv2d_transpose.cl.inc" #include "src/litert/kernel/opencl/utils.h" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h index 3a0b6491677bd9b79a1a7c95875053c290cb7794..bb756aef2ee91cba3c5920c5445e0a6b10151980 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h b/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h index 0abaa292849fbff009adb99dd1e226fe8b0aab44..117ee024406eec7df1b049f10d1f6e86e99f32df 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/crop_parameter.h" +#include "nnacl_c/crop_parameter.h" namespace mindspore::kernel { class CropOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc index 6e3a1888b19a23c4cbf04519aee94029c62507d1..6b923fb79708a9de96e1354f9d8e02c34adb621d 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc @@ -22,8 +22,8 @@ #include #include 
"src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" #include "src/litert/kernel/opencl/cl/depthwise_conv2d.cl.inc" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h index 26172d7aadc998f2010150e2f3d74461600c2b0d..909f61d81a5c1d8ac898f213d256462735f23127 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" using mindspore::lite::opencl::MemType; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h b/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h index 7b78841315417a5260f834129ff0df95c3424d86..f81873a64dc0d3b21bae042706799199a7555dd6 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_ #include -#include "nnacl/base/fill_base.h" +#include "nnacl_c/base/fill_base.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc index 99efc6f8d416fad1550cb366c497373fb10626cb..9f61a96633f82402f26c6ec5c2b75344bc0b27d4 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc @@ -17,7 +17,7 @@ #include #include #include -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/kernel/fullconnection.h" #include "src/litert/kernel/opencl/utils.h" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h index c895a9db7e71f245ab5fada9b798a450bdd8ae7f..da851b055dc275f6ce6b6737acb0b037d4cbae66 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { class FullConnectionOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc index 16e067e1084b6eceb3bb5f3f1be86698300866cc..386796c74fd6d587053a19454d7c8ce3b4839132 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc @@ -17,9 +17,9 @@ #include #include "src/litert/kernel/opencl/utils.h" #include "include/errorcode.h" -#include "nnacl/arithmetic_parameter.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/scale_parameter.h" #include "src/litert/infer_manager.h" using mindspore::lite::RET_ERROR; 
diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h index 30e042af3ea7d4c7af53245fc8ce55b676574679..549004016640d98c81e122b030bd5d43f2620bb6 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h @@ -28,7 +28,7 @@ #include "src/litert/kernel/opencl/kernel/arithmetic_self.h" #include "src/litert/kernel/opencl/kernel/to_format.h" #include "schema/ops_generated.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::schema::ActivationType; using mindspore::schema::PrimitiveType; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h b/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h index b42ccf030115d27b04b468e1b059c9ba80842dc1..98da3725d17b2aeab63db2c08f7966785b1d467a 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/gather_parameter.h" namespace mindspore::kernel { class GatherOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc index cca73b1660d808d000fd460dbfd9e41db906c95d..48cc00eff76abf6c4f17ae197dccc24de7628118 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc @@ -16,12 +16,12 @@ #include "src/litert/kernel/opencl/kernel/int8/arithmetic_int8.h" #include -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/int8/arithmetic.cl.inc" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc index 44982a344ecc2dcafbcfbac773dc5cefcda89068..c7f4cfd491d97bcfe19d1e064388f757920a5c49 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc @@ -19,7 +19,7 @@ #include #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/kernel/layer_norm.h" -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/layer_norm.cl.inc" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h b/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h index 07f971aa45eb9e2f8035806472fe99682c260bd5..3981a8c1cce2d6a08032c93f31f033cc965624d7 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h @@ -20,7 +20,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" #include "src/common/utils.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { class MatMulOpenCLKernel : public OpenCLKernel { diff --git 
a/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h b/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h index d1bec9fbeb7d5a4eaf0acf00dbdc5132c2420b76..dcec9932ffaca4be4d42159aa5d86ded1e53f831 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" namespace mindspore::kernel { class OneHotOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h b/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h index 195ec7c11360dce65c9b04e00001be17dff7a705..253a9adf818573bdcf9c75c88f30ed9016f03427 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h @@ -22,7 +22,7 @@ #include "src/tensor.h" #include "src/litert/kernel/opencl/opencl_kernel.h" #include "schema/model_generated.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" namespace mindspore::kernel { class PadOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h b/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h index f65e93faaa384e58f81d3b3870a0fe673a7e45d8..5e6ef777b096bf4a7856c06c306ed6b98f6c9923 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/pooling_fp32.h" +#include "nnacl_c/fp32/pooling_fp32.h" namespace mindspore::kernel { class PoolingOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/power.h b/mindspore-lite/src/litert/kernel/opencl/kernel/power.h index d2c3eae2e794d97f02ad0f13b2993f99499457cd..33f02aac26b6ea23d6ac2c6352545aeb44fbb154 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/power.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/power.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_ #include -#include "nnacl/fp32/power_fp32.h" +#include "nnacl_c/fp32/power_fp32.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc index 501bcf56d00ab73429d08d623a559c101d004bf2..8c31914032ac8598c148ad46bfc5982b0bd42d05 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc @@ -17,13 +17,13 @@ */ #include "src/litert/kernel/opencl/kernel/prelu.h" -#include +#include "nnacl_c/prelu_parameter.h" #include #include #include "src/litert/kernel/opencl/cl/prelu.cl.inc" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h b/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h index 1b093caa8308fdfb361c102bb1b7a2202184e2be..f66d28b0abf627f8eb8f8667e72ce930d54d9821 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h @@ -21,7 +21,7 @@ 
#include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::kernel { class ReduceOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h b/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h index 74d1098a5c7b213eeb02e54d63295d2e0cd83cbe..8f7d9143c49552c5fa77975b4d57acad4073bf22 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" namespace mindspore::kernel { class ResizeOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc index c63ef367d037ba145afe5053aeb71bdd2f63aa6d..9fdfa3f782d536407c06b38197b25017cb15923d 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc @@ -20,7 +20,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/scale.cl.inc" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h index bac659e01620abfdd911acdd9d019b5295d803c7..09a70acfae1807612a73594cf9384eef84a48bc7 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SCALE_H_ #include -#include "nnacl/scale_parameter.h" +#include "nnacl_c/scale_parameter.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc index bafaa3368774940475b4dac97a433d30bbd54fd8..e15effb79da1191c5170d15748f05b6cf5e1700e 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc @@ -19,7 +19,7 @@ #include "include/errorcode.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "src/litert/kernel/opencl/cl/softmax.cl.inc" using mindspore::kernel::KERNEL_ARCH::kGPU; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h index 84f223969d1ce81694edfc626c872357bfa3bb4d..22f598da0733887c531171928f6a5e8ef3cf9a72 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/softmax_fp32.h" +#include "nnacl_c/fp32/softmax_fp32.h" namespace mindspore::kernel { class SoftmaxOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h index 
6dcd0ae6112dee224eea8e8c2a8e9fe4bc8be805..bde3a88bb883fdd2cf16af9d8c489aabb253a3b2 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" namespace mindspore::kernel { class SpaceToBatchNDOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h index 75ecf703ace53d871619631a9dba60c8c2765cd7..e00c5177f146f353afa74a1637df5c80460ebb1f 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" namespace mindspore::kernel { class SpaceToDepthOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h b/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h index 63181ed7e528142c1a1c5eef07669b3058517ef3..306411b7334050b2c06660a35a8284c4b3cb8f83 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" namespace mindspore::kernel { class SparseToDenseOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/split.h b/mindspore-lite/src/litert/kernel/opencl/kernel/split.h index 9f094f45d0e8317e886719c1a60af9431acc5370..80ebb4e1b4eb97b5c1b086c5e3a19289c80382cd 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/split.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/split.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::kernel { class SplitOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h b/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h index 4e7bef91f45bf574bc4f77d07a2b99a059c6fa8c..9e7f1be4d2c22866b1511e36786ecd832a538aa5 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" namespace mindspore::kernel { class StackOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h b/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h index dbc78f752ed645a09d89e3fa31c3d06b4d9167cb..129de4fa2a29edfad391257fc819321943a28c9b 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/base/slice_base.h" +#include "nnacl_c/base/slice_base.h" namespace mindspore::kernel { class StridedSliceOpenCLKernel : public 
OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h b/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h index 646657354f485311a62bb7327d19e0c75d6a2adb..f7cc7c06bcbbf73eb5ee24dbb80e1c52705e60c5 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc index d1d886dddd6a80f015ea94e11d201b6d2a848e61..b947c2c61e972741e5abdb0990fbb7625c3de100 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/opencl/kernel/winograd.h" #include #include "src/litert/kernel/opencl/cl/winograd.cl.inc" -#include "nnacl/base/minimal_filtering_generator.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/errorcode.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc b/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc index 72186219a8fb8495d19cad4f49c3f69bd03c0294..b4e3ae209ea7a9010fa4708d1b3082bdb6b6a9af 100644 --- a/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc +++ b/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc @@ -27,13 +27,13 @@ #include "include/errorcode.h" #include "schema/ops_generated.h" #include "src/common/utils.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/pad_parameter.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/pad_parameter.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::schema::ActivationType; using mindspore::schema::ActivationType_LEAKY_RELU; diff --git a/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h b/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h index c8c016bda859fbf3a7b728f9fb6646f782f44eb2..86d07453d4a8ecc522558641b6d41d9e11093bc5 100644 --- a/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h +++ b/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h @@ -29,7 +29,7 @@ #include "src/litert/kernel/gpu/opencl/opencl_runtime.h" #include "src/litert/tensor_category.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/opencl/utils.h b/mindspore-lite/src/litert/kernel/opencl/utils.h index 1b70960909c86b452d432d50d0afbadc48de1399..d392c7a1aa29f300852ca2a68fce0d99cac2774f 100644 --- a/mindspore-lite/src/litert/kernel/opencl/utils.h +++ b/mindspore-lite/src/litert/kernel/opencl/utils.h @@ -22,7 +22,7 @@ #include #include "CL/cl2.hpp" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include 
"src/litert/lite_kernel.h" #include "src/common/utils.h" #include "src/litert/kernel/opencl/opencl_kernel.h" diff --git a/mindspore-lite/src/litert/kernel_exec_util.cc b/mindspore-lite/src/litert/kernel_exec_util.cc index 82ffb7e2b75dd4b2310b90e89e640bcbe7cfcada..b4a885ee44081760103e7ea5a5ad09ec6d25fda9 100644 --- a/mindspore-lite/src/litert/kernel_exec_util.cc +++ b/mindspore-lite/src/litert/kernel_exec_util.cc @@ -20,7 +20,7 @@ #include #include #include "src/executor/sub_graph_kernel.h" -#include "nnacl/call_parameter.h" +#include "nnacl_c/call_parameter.h" #if GPU_OPENCL #include "src/litert/kernel/opencl/opencl_subgraph.h" #include "src/litert/kernel/gpu/opencl/opencl_runtime.h" diff --git a/mindspore-lite/src/litert/kernel_registry.cc b/mindspore-lite/src/litert/kernel_registry.cc index f69086e6fb8b7fb180d6e23ce0b8405d3318d799..d28eaf5b54d155eb48e53c54298ea61d55c9c514 100644 --- a/mindspore-lite/src/litert/kernel_registry.cc +++ b/mindspore-lite/src/litert/kernel_registry.cc @@ -22,7 +22,7 @@ #endif #include "src/common/ops/populate/populate_register.h" #include "src/common/version_manager.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" #if defined(ENABLE_FP16) && defined(ENABLE_ARM) #if defined(__ANDROID__) #include diff --git a/mindspore-lite/src/litert/lite_kernel.h b/mindspore-lite/src/litert/lite_kernel.h index 6b58b59dba540c7f7dc172f2adcbaec92db8164c..501ab5d1f47b4eb12b1995669bdf0a9eed0e8759 100644 --- a/mindspore-lite/src/litert/lite_kernel.h +++ b/mindspore-lite/src/litert/lite_kernel.h @@ -23,7 +23,7 @@ #include #include "src/common/utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/inner_context.h" #include "src/tensor.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/lite_model.h b/mindspore-lite/src/litert/lite_model.h index 2e62655ca6c2a5809ff2d3b080e91ec5680b2058..24a4da0ef096447d51466d8095b9491870d35188 100644 --- a/mindspore-lite/src/litert/lite_model.h +++ b/mindspore-lite/src/litert/lite_model.h @@ -27,7 +27,7 @@ #include "src/common/log_adapter.h" #include "src/common/version_manager.h" #include "src/litert/schema_tensor_wrapper.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/prim_util.h" #ifdef ENABLE_MODEL_OBF #include "tools/obfuscator/include/deobfuscator.h" diff --git a/mindspore-lite/src/litert/mindrt_executor.cc b/mindspore-lite/src/litert/mindrt_executor.cc index 7162ea696ef2ab4d2a5e22572799b5514ef15fc6..16c454767050e05595fa0b4ab6ba5d9cebaaf9b6 100644 --- a/mindspore-lite/src/litert/mindrt_executor.cc +++ b/mindspore-lite/src/litert/mindrt_executor.cc @@ -23,9 +23,9 @@ #include "src/common/common.h" #include "src/common/tensor_util.h" #ifdef ENABLE_FP16 -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" #endif -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #include "src/litert/kernel_exec_util.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/litert/pass/format_pass/format_pass.cc b/mindspore-lite/src/litert/pass/format_pass/format_pass.cc index c18ae7f2ebfc1b9f8643cf4b6013a4a50a3f1246..0be0ad790edb39f42477aa50e11bd2d3378dc605 100644 --- a/mindspore-lite/src/litert/pass/format_pass/format_pass.cc +++ b/mindspore-lite/src/litert/pass/format_pass/format_pass.cc @@ -19,7 +19,7 @@ #include "src/litert/pass/format_pass/eliminate_transpose.h" #ifdef ENABLE_MULTI_LAYOUT #include "src/litert/kernel_registry.h" -#include 
"nnacl/format_transpose_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" #endif #include "src/common/draw/drawer.h" diff --git a/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc b/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc index b5543f767ffc9e1d2895255cd85a1659d0c7a8a2..9b71774bc3ee6ce3349a6883745c2900b77b21be 100644 --- a/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc +++ b/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc @@ -17,7 +17,7 @@ #include "src/litert/pass/format_pass/insert_transpose.h" #include "src/litert/pass/format_pass/format_utils.h" #include "src/litert/kernel_exec_util.h" -#include "nnacl/base/format_transpose.h" +#include "nnacl_c/base/format_transpose.h" namespace mindspore::lite::pass { int InsertTranspose::TransposeConstData(kernel::KernelExec *kernel, size_t index) { diff --git a/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc b/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc index 7ba83911de3016cb6debc8ef41d008f4603bae2d..c90674330ce55274eb3b6b1dc02c06e3c8e40c56 100644 --- a/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc +++ b/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc @@ -17,8 +17,8 @@ #include "src/litert/pass/format_pass/pass_utils.h" #include #include -#include "nnacl/format_transpose_parameter.h" -#include "nnacl/arg_min_max_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" +#include "nnacl_c/arg_min_max_parameter.h" namespace mindspore::lite::pass { bool IsNoneTranspose(const TransInfoPair &trans) { diff --git a/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc b/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc index c68cc9d16e86ccb593618b678ed04b371f636145..15ed7a3e5a5113dabacb7a697cf6d9f2c9415e3d 100644 --- a/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc +++ b/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc @@ -15,18 +15,18 @@ */ #include "src/litert/pass/format_pass/transpose_strategy.h" -#include "nnacl/op_base.h" -#include "nnacl/arg_min_max_parameter.h" -#include "nnacl/concat_parameter.h" -#include "nnacl/crop_parameter.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/split_parameter.h" -#include "nnacl/squeeze_parameter.h" -#include "nnacl/stack_parameter.h" -#include "nnacl/unsqueeze_parameter.h" -#include "nnacl/unstack_parameter.h" -#include "nnacl/slice_parameter.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/crop_parameter.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/squeeze_parameter.h" +#include "nnacl_c/stack_parameter.h" +#include "nnacl_c/unsqueeze_parameter.h" +#include "nnacl_c/unstack_parameter.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" namespace mindspore::lite::pass { static const std::set arithmetic_kernel_lists = { diff --git a/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc index 2c29023f85192312ad26fefe5d8c1b61fe577284..773d1c23fcc028817db0eec671d1dd3d4af92f21 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc @@ -18,7 +18,7 @@ #include #include 
"src/litert/pass/online_fusion/online_fusion_utils.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" #include "include/model.h" namespace { diff --git a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc index c5f802b2afd28ea44b4a9bbdb52fcfeb8151ebe5..7c458f399906ac4cc217f51584486a41f04d2f06 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc @@ -17,9 +17,9 @@ #include "src/litert/pass/online_fusion/online_fusion_pass.h" #include #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" -#include "nnacl/reduce_parameter.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/reduce_parameter.h" +#include "nnacl_c/concat_parameter.h" #include "include/model.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h index 7ea6fe7f94932fc31b12201d1a281a0759945090..d9574ee51f6db81a0d971e60a277f33800d3d37b 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h +++ b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h @@ -29,8 +29,8 @@ #include "src/litert/sub_graph_split.h" #include "src/litert/pass/online_fusion/online_fusion_pass_registry.h" #include "src/common/prim_util.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite { class OnlineFusionPass { diff --git a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc index 431534328c835632c5c4ca2fb0c9cd3e13ec13eb..2ff41458a1362016d595c4305c6d8aeae082603a 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc @@ -18,8 +18,8 @@ #include #include "src/litert/pass/online_fusion/online_fusion_utils.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reduce_parameter.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/reduce_parameter.h" +#include "nnacl_c/concat_parameter.h" #include "include/model.h" namespace { diff --git a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h index 52d8311e2e3f245f1189fc614e711b7805c0670c..be31320505aeefbc53481183eccaf71e1722a33e 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h +++ b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h @@ -29,7 +29,7 @@ #include "src/litert/sub_graph_split.h" #include "src/litert/pass/online_fusion/online_fusion_pass.h" #include "src/common/prim_util.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite { class ReduceConcatOnlineFusionPass : public OnlineFusionPass { diff --git a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc index 18a835e6c72cd9fda4b9fcfbf17ce6b6ba21cd7e..76c44638d3b6614559c1ec8375397bfd9e1aede6 100644 --- 
a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc @@ -17,9 +17,9 @@ #include "src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h" #include #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" -#include "nnacl/concat_parameter.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/reduce_parameter.h" #include "include/model.h" namespace { diff --git a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h index aca4c8305040c6fb204f1309b23193a49927af12..4b65e5ba1e1b96448e97e820de5b5d733085719e 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h +++ b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h @@ -29,7 +29,7 @@ #include "src/litert/sub_graph_split.h" #include "src/litert/pass/online_fusion/online_fusion_pass.h" #include "src/common/prim_util.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite { class SplitReduceConcatOnlineFusionPass : public OnlineFusionPass { diff --git a/mindspore-lite/src/litert/runtime_packed_node_pass.cc b/mindspore-lite/src/litert/runtime_packed_node_pass.cc index ed7f54b9e1ecb8427ae341be18c2f9dd0c8af87b..85fdb39574170982ba73f5bfdf185a5686bb88f4 100644 --- a/mindspore-lite/src/litert/runtime_packed_node_pass.cc +++ b/mindspore-lite/src/litert/runtime_packed_node_pass.cc @@ -14,10 +14,10 @@ * limitations under the License. */ #include "src/litert/runtime_packed_node_pass.h" -#include "nnacl/op_base.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" #include "nnacl/nnacl_kernel.h" -#include "nnacl/kernel/matmul_struct.h" +#include "nnacl_c/kernel/matmul_struct.h" #include "common/string_utils.h" using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool); diff --git a/mindspore-lite/src/litert/runtime_pass.cc b/mindspore-lite/src/litert/runtime_pass.cc index da92cca59e94f675f44b40406ccc3f5a4be61b7e..a98cbbaf489775e3291ad94b8c09603f061db452 100644 --- a/mindspore-lite/src/litert/runtime_pass.cc +++ b/mindspore-lite/src/litert/runtime_pass.cc @@ -15,7 +15,7 @@ */ #include "src/litert/runtime_pass.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite { #ifndef RUNTIME_PASS_CLIP diff --git a/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc b/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc index 79b101271422a8cc7c84c74b6035ddb915545ec3..726138b44befc41ed6c0be9a5d84c54f47aa7374 100644 --- a/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc +++ b/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc @@ -21,7 +21,7 @@ #include #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { diff --git a/mindspore-lite/src/litert/scheduler.cc b/mindspore-lite/src/litert/scheduler.cc index 440ec0d24eaa17f58291d5ef70bc866481db7c57..2e9a684470d3c45a55ebddeea22e2caae9ddaed6 100644 --- a/mindspore-lite/src/litert/scheduler.cc +++ b/mindspore-lite/src/litert/scheduler.cc @@ -22,7 +22,7 @@ #include #include #include "src/tensorlist.h" -#include "nnacl/partial_fusion_parameter.h" 
+#include "nnacl_c/partial_fusion_parameter.h" #include "include/errorcode.h" #include "src/common/graph_util.h" #include "src/common/utils.h" @@ -47,7 +47,7 @@ #endif #include "src/litert/weight_decoder.h" #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h" -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #if GPU_OPENCL #include "src/litert/kernel/opencl/opencl_subgraph.h" #include "src/litert/kernel/gpu/opencl/opencl_runtime.h" diff --git a/mindspore-lite/src/litert/schema_tensor_wrapper.cc b/mindspore-lite/src/litert/schema_tensor_wrapper.cc index 5ba0a55a7d5ac71fcb3f1bdf10d7ff049e8658a7..bc608f181dbe98ec6315bda892ab09863e3fc0b9 100644 --- a/mindspore-lite/src/litert/schema_tensor_wrapper.cc +++ b/mindspore-lite/src/litert/schema_tensor_wrapper.cc @@ -17,7 +17,7 @@ #include "src/litert/schema_tensor_wrapper.h" #include "src/common/log_adapter.h" #include "src/common/file_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/litert/sub_graph_split.cc b/mindspore-lite/src/litert/sub_graph_split.cc index b1655c8cda2fa5f509fa6de58e559fd92e7aef2c..ef5f0798088c864eafddda850cbc853fcf520e5e 100644 --- a/mindspore-lite/src/litert/sub_graph_split.cc +++ b/mindspore-lite/src/litert/sub_graph_split.cc @@ -27,9 +27,9 @@ #include "src/common/ops/populate/populate_register.h" #include "src/litert/scheduler.h" #include "src/litert/tensor_category.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" #include "include/model.h" -#include "nnacl/base/conv_common_base.h" +#include "nnacl_c/base/conv_common_base.h" namespace { constexpr const int kMaxDepth = 2048; diff --git a/mindspore-lite/src/litert/sub_graph_split.h b/mindspore-lite/src/litert/sub_graph_split.h index 2b16aedb1e3e64af3d9cd15646a6b15ad10266a9..0bb7f08e1276297d2ce4681192ae0d1f7abda4d8 100644 --- a/mindspore-lite/src/litert/sub_graph_split.h +++ b/mindspore-lite/src/litert/sub_graph_split.h @@ -27,7 +27,7 @@ #include "src/litert/lite_model.h" #include "src/litert/inner_context.h" #include "src/common/prim_util.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite { constexpr int kDefaultSubGraphSize = 2; diff --git a/mindspore-lite/src/litert/thread_cost_model.cc b/mindspore-lite/src/litert/thread_cost_model.cc index 5bbd3d36cdabd1c1ded4604915fef8e9a00fef70..d3e6f5c031411d09d0578cb14c4e1ec89dfec85d 100644 --- a/mindspore-lite/src/litert/thread_cost_model.cc +++ b/mindspore-lite/src/litert/thread_cost_model.cc @@ -19,7 +19,7 @@ #include "src/common/log_util.h" #include "src/litert/inner_context.h" #include "thread/threadpool.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { const std::map kernel_compute_cost_map_ = { diff --git a/mindspore-lite/src/litert/thread_cost_model.h b/mindspore-lite/src/litert/thread_cost_model.h index 70c9ca9db13bbeadf2e65eedb891433c5c5e9e5e..0254b1efd68f6bdc083c62e6b16388891e9db227 100644 --- a/mindspore-lite/src/litert/thread_cost_model.h +++ b/mindspore-lite/src/litert/thread_cost_model.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_THREAD_COST_MODEL_H_ #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/api/context.h" #include "schema/ops_generated.h" diff --git a/mindspore-lite/src/litert/weight_decoder.cc b/mindspore-lite/src/litert/weight_decoder.cc index aa6ca73bc6e8b7127e94a6a235eaed7f5229f5e6..5e21960fdf59c15eb24443c35fc46d1f03edfad1 100644 --- 
a/mindspore-lite/src/litert/weight_decoder.cc +++ b/mindspore-lite/src/litert/weight_decoder.cc @@ -18,7 +18,7 @@ #include "src/litert/weight_decoder.h" #include "src/litert/huffman_decode.h" #include "tools/converter/quantizer/fse_decoder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite { #ifndef WEIGHT_DECODE_CLIP diff --git a/mindspore-lite/src/litert/weight_decoder.h b/mindspore-lite/src/litert/weight_decoder.h index 4df0eb82f9ed525370e23d9198e6c4ca414b9f00..371b3eea0bc7564183f15beeb415da1ed68ad37e 100644 --- a/mindspore-lite/src/litert/weight_decoder.h +++ b/mindspore-lite/src/litert/weight_decoder.h @@ -24,8 +24,8 @@ #include #include #include -#include "nnacl/matmul_parameter.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/gather_parameter.h" #include "src/executor/kernel_exec.h" #include "src/common/utils.h" #include "src/tensor.h" diff --git a/mindspore-lite/src/tensor.h b/mindspore-lite/src/tensor.h index a02a51b9dda170e7c289d1743bb4235a8e46d209..724674db8b6eeabd98f7970cee792d2421f51733 100644 --- a/mindspore-lite/src/tensor.h +++ b/mindspore-lite/src/tensor.h @@ -26,8 +26,8 @@ #include #include "include/api/format.h" #include "include/lite_types.h" -#include "nnacl/tensor_c.h" -#include "nnacl/tensor_c_utils.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/tensor_c_utils.h" #include "src/litert/inner_allocator.h" #include "src/common/log_adapter.h" #include "src/common/utils.h" diff --git a/mindspore-lite/src/tensorlist.cc b/mindspore-lite/src/tensorlist.cc index 13fe91e3ec1fd7b31b1cc744807e50c533f48c92..10d87438653dcec5fcf5dfc774a976b6364c879b 100644 --- a/mindspore-lite/src/tensorlist.cc +++ b/mindspore-lite/src/tensorlist.cc @@ -19,7 +19,7 @@ #include #include "src/common/log_adapter.h" #include "src/tensor.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { #ifndef CONTROLFLOW_TENSORLIST_CLIP diff --git a/mindspore-lite/src/tensorlist.h b/mindspore-lite/src/tensorlist.h index 5c85e29d40105e99eee58ba7a76a9471a8d0e04b..406dd5fa9afc723246bf44c2281c28363be7fe9a 100644 --- a/mindspore-lite/src/tensorlist.h +++ b/mindspore-lite/src/tensorlist.h @@ -20,7 +20,7 @@ #include #include #include "include/errorcode.h" -#include "nnacl/tensorlist_c.h" +#include "nnacl_c/tensorlist_c.h" #include "src/common/log_adapter.h" #include "schema/model_generated.h" #include "src/tensor.h" diff --git a/mindspore-lite/src/train/opt_allocator.cc b/mindspore-lite/src/train/opt_allocator.cc index d9931641c296ab981e260293818e08f9f2ee412e..485e4b9742411209a7171a3aa2e8bdd1da2f6e4a 100644 --- a/mindspore-lite/src/train/opt_allocator.cc +++ b/mindspore-lite/src/train/opt_allocator.cc @@ -15,7 +15,7 @@ */ #include "src/train/opt_allocator.h" #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { size_t OptAllocator::FindFree(size_t size) { diff --git a/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc b/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc index 435686e50a8370e189a04f1f397870ef21347e94..91655293ed7fcf2f341bbb401de7a360156ac93e 100644 --- a/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc +++ b/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc @@ -24,7 +24,7 @@ #include #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/src/train/train_loop.cc b/mindspore-lite/src/train/train_loop.cc index b565443b1cb870463386fb4eecdf27b63484e0bb..125c4edea337d5622189485ae0ce40dcb8eab872 100644 --- a/mindspore-lite/src/train/train_loop.cc +++ b/mindspore-lite/src/train/train_loop.cc @@ -22,7 +22,7 @@ #include "include/errorcode.h" #include "include/dataset/iterator.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/train/train_populate_parameter.cc b/mindspore-lite/src/train/train_populate_parameter.cc index c1a70b921d585adb03dc4f7fbc0c4ebade2fdf61..3cfdd24d62c4c35cb5537e58de328df4f06b2a9e 100644 --- a/mindspore-lite/src/train/train_populate_parameter.cc +++ b/mindspore-lite/src/train/train_populate_parameter.cc @@ -17,21 +17,21 @@ #include "include/securec.h" #include "src/common/ops/populate/populate_register.h" #include "src/common/ops/populate/default_populate.h" -#include "nnacl/strided_slice_parameter.h" -#include "nnacl/arithmetic_parameter.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/pow_parameter.h" -#include "nnacl/activation_parameter.h" -#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h" -#include "nnacl/fp32_grad/optimizer.h" -#include "nnacl/fp32_grad/batch_norm_parameter.h" -#include "nnacl/fp32_grad/dropout_parameter.h" -#include "nnacl/fp32_grad/smooth_l1_loss.h" -#include "nnacl/fp32_grad/resize_grad_parameter.h" -#include "nnacl/fp32_grad/lstm_grad_fp32.h" -#include "nnacl/fp32_grad/binary_cross_entropy.h" -#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" +#include "nnacl_c/strided_slice_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/pow_parameter.h" +#include "nnacl_c/activation_parameter.h" +#include "nnacl_c/fp32_grad/softmax_crossentropy_parameter.h" +#include "nnacl_c/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/batch_norm_parameter.h" +#include "nnacl_c/fp32_grad/dropout_parameter.h" +#include "nnacl_c/fp32_grad/smooth_l1_loss.h" +#include "nnacl_c/fp32_grad/resize_grad_parameter.h" +#include "nnacl_c/fp32_grad/lstm_grad_fp32.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy_grad.h" using mindspore::lite::Registry; diff --git a/mindspore-lite/test/CMakeLists.txt b/mindspore-lite/test/CMakeLists.txt index 0d0760860e7346dda629739f5f7826e28f41eca5..f44d75b92e379371f45fe4fef2ed9c441187a35a 100644 --- a/mindspore-lite/test/CMakeLists.txt +++ b/mindspore-lite/test/CMakeLists.txt @@ -4,7 +4,7 @@ set(LITE_DIR ${TOP_DIR}/mindspore-lite) include_directories(${TOP_DIR}) include_directories(${TEST_DIR}) -include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu) +include_directories(${NNACL_DIR}/../) include(${TOP_DIR}/cmake/external_libs/gtest.cmake) include(${TOP_DIR}/cmake/external_libs/mockcpp.cmake) diff --git a/mindspore-lite/test/common/common_test.h b/mindspore-lite/test/common/common_test.h index aef9e7d9d0dbefee65bc3492f1334a9d06136351..e8f32d3b7b42701aa88e04895210ee5fc02923dc 100644 --- a/mindspore-lite/test/common/common_test.h +++ b/mindspore-lite/test/common/common_test.h @@ -25,8 +25,8 @@ #include "gtest/gtest.h" #include "include/api/format.h" #include "src/litert/tensor_category.h" -#include "nnacl/tensorlist_c_utils.h" -#include "nnacl/tensor_c_utils.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include 
"nnacl_c/tensor_c_utils.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc index 31834b01b44f4ee975494b4b1ae5dc6bbbc55562..355bd16e298773956929d77beba7280c3d97a4b0 100644 --- a/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/adam_infer.h" +#include "nnacl_c/infer/adam_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc index 52dae7fbe34a2b7d364969b8cd826a513d834733..62f8654c4c9b31fece0998003b179f3df9e4c60d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" -#include "nnacl/infer/adam_weight_decay_infer.h" +#include "nnacl_c/infer/adam_weight_decay_infer.h" namespace mindspore { class AdamWeightDecayInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc index 0604cfa8cf0ce02491dec72a8880a4893c394ca8..6d5cd2c8188ad61391f7874fc22ee975c49cdaac 100644 --- a/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/addn_infer.h" +#include "nnacl_c/infer/addn_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc index 5c21e8233cc38994d07b065d42739ac0eb14fcf0..ee3b919c7085b9bce20d30da21127763873776d2 100644 --- a/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/apply_momentum_infer.h" +#include "nnacl_c/infer/apply_momentum_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc index d003c9d27e1958c8992caec2c9daeb791d186866..10ecad2c5e791973c5f5fb0e29330e6f481e7c76 100644 --- a/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/argmin_max_infer.h" +#include "nnacl_c/infer/argmin_max_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc index 7fe54cfd97150ae549c4b67002f97ed7d8c2dd66..1c6ed96af96dd2a7750fcbaad4a4e73a4975183e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/argmin_max_infer.h" +#include "nnacl_c/infer/argmin_max_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc index 869edf1d4d61d647c211070c310310dcd3cc3688..c1096310bf9719aa216d393f738e1013e3c7b4cf 100644 --- a/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/arithmetic_compare_infer.h" +#include "nnacl_c/infer/arithmetic_compare_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc index 146c3f0f80d1b7cd083a814b347b8acabee6a067..fa759509dfa77bf63df58ac6e463804717ab1730 100644 --- a/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/arithmetic_infer.h" +#include "nnacl_c/infer/arithmetic_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc index 6f6e823ff00a5b08d158f6c99718c7caed034791..b2055260ab328830c56b132c04ae94b1c22ded51 100644 --- a/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/assign_add_infer.h" +#include "nnacl_c/infer/assign_add_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc index 8263392c02fb49ca1e495b2c0240dce658076842..739c25be950fb3eea8a39b9fd1d4ff5843122b36 100644 --- a/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/assign_infer.h" +#include "nnacl_c/infer/assign_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc index 5c724ae1abef9cb58da2b1c6d4e2ed556b7f0b22..14497e7314f54a7b394f447a92fb3f933db21aa1 100644 --- a/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/audio_spectrogram_infer.h" +#include "nnacl_c/infer/audio_spectrogram_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc index 0541102b75ac6c3af77d102ed1a8dd9dc86208e3..2985f2b706305db53fdd4fc0a63bbaebdc4f650e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/batch_to_space_infer.h" +#include "nnacl_c/infer/batch_to_space_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc index 55c5d2de7e8e4b04b127b4e589971e6dec01216e..855e0a4e71d15858067d6387f00d6e04ef7a8ecd 100644 --- a/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/bias_grad_infer.h" +#include "nnacl_c/infer/bias_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc index b6a7e3f83e957a51c05a33fc5d2560e75ef15897..25c51b064fbf57ecdaa87ab1268018a4f2741abd 100644 --- a/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/binary_cross_entropy_infer.h" +#include "nnacl_c/infer/binary_cross_entropy_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc index 8fbabbca84cb52d09c9275761e06fbe9641c2856..9ea08ec6a3f22d6c1be0293863f9f5b80cff6ae3 100644 --- a/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/bn_grad_infer.h" +#include "nnacl_c/infer/bn_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc index 31e2b55c8cc3ff8a3a1aac110859b6ac5364e180..34f44a0434834d3764ec803e6d8dfd0f2efa249d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/broadcast_to_infer.h" +#include "nnacl_c/infer/broadcast_to_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc index 3c4906773e05e258907e77f2e7819f169d670280..18b54fce7a0f0787ddd0fba3c295e7249d7cba0c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/cast_infer.h" +#include "nnacl_c/infer/cast_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc index 03624cfd13eca15e23bbdc3225b00d7a7f933461..a3b15dd2533ca5346d99cee5e5afc3601ce29117 100644 --- a/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/concat_infer.h" +#include "nnacl_c/infer/concat_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc index 792ea55bf71aafe1ce9ee573e7119b311c68c592..7f4cbd53f8ecff7bbfc8abb13e0149a3cc4b5574 100644 --- a/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/constant_of_shape_infer.h" +#include "nnacl_c/infer/constant_of_shape_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc index 868768e82769945b55388e0d80792434a8aa29ba..1d22281e4ccf2bcc2a82fbd8b542a90148d7829c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/conv2d_grad_filter_infer.h" +#include "nnacl_c/infer/conv2d_grad_filter_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc index 12e0fbb17461baa32d2d54ef98ceac8d435b2b5b..cce2634dc8bd242335ce846975be257ec15bde30 100644 --- a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/conv2d_grad_input_infer.h" +#include "nnacl_c/infer/conv2d_grad_input_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc index 749cddb28e407990d6d8dd25f4d547e854e81336..3f88a643e8111c69adb0089182a5e2020707358d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/conv2d_infer.h" +#include "nnacl_c/infer/conv2d_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc index 01cfd8a1661af7af72cec1afa74dff3bc8dd3118..5bee340299dc1c43df41420cdf3ad72a8805e8b9 100644 --- a/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/crop_and_resize_infer.h" +#include "nnacl_c/infer/crop_and_resize_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc index ca1798d10d50a73c3fa483968463d185941dfae8..434d97a1a83bb2e36a9fa45d6d27226c73e197f5 100644 --- a/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/crop_infer.h" +#include "nnacl_c/infer/crop_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc index e66522f0c8b9e29d58979b05387226eefc6d26c0..5a09a79614825c396f18dd1837cd21925faa589c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/cumsum_infer.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/infer/cumsum_infer.h" +#include "nnacl_c/cumsum_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc index cd1ded3f5fd29144e0c3ecc4306ffe732517f071..c38148aa825dece11450e08cd6f852737f8f0600 100644 --- a/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/custom_extract_features_infer.h" +#include "nnacl_c/infer/string/custom_extract_features_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc index 1c84fdd7215bed6225957d39131431c9087b3f9e..081aa543df68b3edd7dd9e66dcf488c66aa32379 100644 --- a/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/custom_normalize_infer.h" +#include "nnacl_c/infer/string/custom_normalize_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc index b908aa7a344c629ee892106609b9b4b941821176..415c5c88644f3723ee9ffe71371115c348331b63 100644 --- a/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/custom_predict_infer.h" +#include "nnacl_c/infer/string/custom_predict_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc index aa2360f0f4a2f3012c99aa4e5d91af02964a2e41..3e1f9579601568b5dc5b9cbe8d3d766d14bc8077 100644 --- a/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/deconv2d_infer.h" +#include "nnacl_c/infer/deconv2d_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc index 7b947e4b437c3a619dc5a1f245ceeed2305351ad..2508e21248278b73cacabf2d411814fc047629b6 100644 --- a/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/depth_to_space_infer.h" +#include "nnacl_c/infer/depth_to_space_infer.h" #include "src/tensor.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc index 74e9d27e9cb3a52d08c6e2955ff459e4146c6065..cf96e455d0f6e28d413596c35d5f9cbb8e525f7e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/depthwise_conv2d_infer.h" +#include "nnacl_c/infer/depthwise_conv2d_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc index 9dcca21df6d371df01f3816bb43efcdd77065374..0321ff0c489f60e6368de526d2f0abfd0649ed44 100644 --- a/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/detection_post_process_infer.h" +#include "nnacl_c/infer/detection_post_process_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc index 758205e4384461e8ac90a71bed245c579baec3dc..9731d71e18602520090ee89f90b9fc8e65521fca 100644 --- a/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/dropout_grad_infer.h" +#include "nnacl_c/infer/dropout_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc index 98b8cf24749ba4275d76872fdd92b5c923d3ab7d..995b129fc6af7ea4d7e7cbeae4a22a3c8b9b458c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/embedding_lookup_infer.h" +#include "nnacl_c/infer/embedding_lookup_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc index 92c099f863f01c6d937f697cf2b839d70299ff9e..039ae90ae98162dcd4d4dd56f15466fac5f12bac 100644 --- a/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/expand_dims_infer.h" +#include "nnacl_c/infer/expand_dims_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc index 3073083181243f1c94b49b1c6dbea30c468542d0..36390924a425368025504807ce34e87acf986a6b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/fft_imag_infer.h" +#include "nnacl_c/infer/fft_imag_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc index 2ad4b009f5251deebd5351b72979bc68b9a9afaf..451462af3a1591d7ddd63ad6c0c04471caa6cfa9 100644 --- a/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/fill_infer.h" -#include "nnacl/fill_parameter.h" +#include "nnacl_c/infer/fill_infer.h" +#include "nnacl_c/fill_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc index 686fdf22a5e039620d94ff7658dfec1a9698da32..acad0586afc073fd2fb7b2d54af89c1b8c303a00 100644 --- a/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/flatten_grad_infer.h" +#include "nnacl_c/infer/flatten_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc index f815279349cc78cd774afc1b1af641dca2f49011..24029984237a58685a1481b74e9f7b07bc119ae0 100644 --- a/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/flatten_infer.h" +#include "nnacl_c/infer/flatten_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc index 0dd33a6a9c82c4f55deeb583e1c3c6b68367ddb6..33f147462b9224544e2d7dfd33faf16765353f5a 100644 --- a/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/full_connection_infer.h" +#include "nnacl_c/infer/full_connection_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc index 5df7e3e20a49ac3ffbe4527e6091bc310d2f2df3..b5bdd02790332d26a6c39f8b5080039384680792 100644 --- a/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/fused_batchnorm_infer.h" +#include "nnacl_c/infer/fused_batchnorm_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc index 079c9c3c7d9a5d9cf8b49750dc87a35a94b92489..5d35f7a86fa0b332a26d5d0d615e96e3c8493d1a 100644 --- a/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/gather_infer.h" +#include "nnacl_c/infer/gather_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc index 26fd731c704f1e94c113519951cc051567be6727..5c203271acff3e4b2a8dcc35e7c2c150ec57dc7d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/gather_nd_infer.h" -#include "nnacl/gather_nd_parameter.h" +#include "nnacl_c/infer/gather_nd_infer.h" +#include "nnacl_c/gather_nd_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc index f687aa7dd23999774e6528b5205a5f1aeb13de4a..d2b9b3a76b087dba0a785c76931020413a989a06 100644 --- a/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/group_conv2d_grad_input_infer.h" +#include "nnacl_c/infer/group_conv2d_grad_input_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc index 6f917cf144bc966b1a5a4f1c3e3522ca89f0e32a..dde9f8a7b51156ffeeb57d87c60c70cae9193daa 100644 --- a/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/gru_infer.h" +#include "nnacl_c/infer/gru_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc index 4768bedf7e4b83a8efc850cd3f00944c1e11ee36..37eb81144a2c90aa839636b1f63968cbc7a9707b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/string/hashtable_lookup_infer.h" +#include "nnacl_c/infer/string/hashtable_lookup_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc index 78b5666f6edee7db58b575719d1305b994838e98..07467648b33d69ff3c9dcb651eb6c3c804c562f7 100644 --- a/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/invert_permutation_infer.h" +#include "nnacl_c/infer/invert_permutation_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc index 167f8a039fd23ea8d00450dafb099d658d7675a2..6849818793fb91640e95f78cc4df5a3e2d2f3921 100644 --- a/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/layer_norm_infer.h" +#include "nnacl_c/infer/layer_norm_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc index 33717760b18fad0b22fb24692b8dda9e0485a373..72a19e046e2f17177d046aac9e573d33730ea88f 100644 --- a/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/lsh_projection_infer.h" +#include "nnacl_c/infer/string/lsh_projection_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc index ef24be40e6c7c7274c59c6a2e13ef4da323947c4..5c3434ebf3f9d3df270533e7e06e1ac5a2cb8046 100644 --- a/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/lstm_infer.h" +#include "nnacl_c/infer/lstm_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc index 1c2c807dfdd0ebbb2d566be0736056e21794ddf9..e5e966dc10a98b3f0c69d41e5b40340ae4f17ed9 100644 --- a/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/matmul_infer.h" +#include "nnacl_c/infer/matmul_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc index 107bcbf3661e7660efb7f6b1a1b9713fd200308a..740a727878a7665d72de59d082699f99c93e62f8 100644 --- a/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/max_min_grad_infer.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/infer/max_min_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc index f6c2a615c82e1c38b91e3e1aa09cd0eb70a9fe8c..7efce63be785ec1729f6a0c82f64d07131560b01 100644 --- a/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/mfcc_infer.h" +#include "nnacl_c/infer/mfcc_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc index 260dcfa7f7fbace8c5403fbe776c016ecae6f12c..0c0480b8ef581cb89854dded1b336317719445fa 100644 --- a/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" -#include "nnacl/infer/nllloss_grad_infer.h" +#include "nnacl_c/infer/nllloss_grad_infer.h" namespace mindspore { class TestNLLLossGradInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc index 7de3cc4f0deb00f5bd6d1b59bda5f953659b22ce..416d2299dede04be6a33026601bfa8ae6ee1f84e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" -#include "nnacl/infer/nllloss_infer.h" +#include "nnacl_c/infer/nllloss_infer.h" namespace mindspore { class TestNLLLossInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc index 0324351882b84f7d360c9415cea914ba0b86c8b7..0752eafd6e64743338a62f2ad2c0da0cf417cc84 100644 --- a/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/one_hot_infer.h" +#include "nnacl_c/infer/one_hot_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc index 2c6473e1037b24874c43477ece40fb68a016d1a9..b66b40a55e1e736acaeef00fdfeaaf5f71e6fb19 100644 --- a/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/pad_infer.h" +#include "nnacl_c/infer/pad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc index 74c32a2fd9e9935a40ec36111f1cebdcea156906..bb776ea3046b4b28754b0bf8b71041d5b0c6578b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/pooling_grad_infer.h" +#include "nnacl_c/infer/pooling_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc index 0fb2789282dd660be51a51de13ad816e561451d1..9f4d72d58d732f04d08611c00a3a07006b75f01b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/pooling_infer.h" +#include "nnacl_c/infer/pooling_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc index dc7b10f7ce95b672259b26d23ff90783e3cf1d94..0047b71d0f0d0af0ececa3c0a80296a9d56b756e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/power_infer.h" +#include "nnacl_c/infer/power_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc index 9d60d2f495c156785a9e205584905dfe07c2c3ae..7c99fb158538c84d58d2e0e0e3422f1163a196bc 100644 --- a/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/quant_dtype_cast_infer.h" +#include "nnacl_c/infer/quant_dtype_cast_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc index ab81dcdecd5a7b144ef274e6e37c7fdf8a1e26e7..c1c6e56ec247f7fd63cb1c2e7f740d4e38973b2f 100644 --- a/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/random_standard_normal_infer.h" +#include "nnacl_c/infer/random_standard_normal_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc index 067490b741df200398f54cfb063facf5369ed313..8f40763d160de830782800989f36f8c6325f5432 100644 --- a/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/range_infer.h" -#include "nnacl/range_parameter.h" +#include "nnacl_c/infer/range_infer.h" +#include "nnacl_c/range_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc index 8b6000c6c005641a63d042b06ee0d6c0d24fabad..c2c99e47d9bf4203081c41bc1488079422a8cf38 100644 --- a/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/rank_infer.h" +#include "nnacl_c/infer/rank_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc index a1823f74b0431d5912619de05e49a6fad5bd0f8f..a5bf6ed76d3a96fd29345ae7e04117fc957a1607 100644 --- a/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/reduce_infer.h" +#include "nnacl_c/infer/reduce_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc index ade7305d3fdefc8897b5fc962073a516025fee0f..47b4be80992d6fb067cce548c28946898630c192 100644 --- a/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/reshape_infer.h" -#include "nnacl/reshape_parameter.h" +#include "nnacl_c/infer/reshape_infer.h" +#include "nnacl_c/reshape_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc index 5ac041cd6ba0ca6297f8b24bc47a26f967ce8d18..4aaf0a75c7a779db8fe86cbe6d17e4cac861c55b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/resize_infer.h" +#include "nnacl_c/infer/resize_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc index 20a33e82c240f34b455450384d69dfd6a44d6cae..ce630250a77e9772c66fbd133206db02c758e3d6 100644 --- a/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/rfft_infer.h" +#include "nnacl_c/infer/rfft_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc index 25293239b70537327b640dba7762af790f2b8e2f..ee8cc2a7caf42f75928210cb08e3a68141c9f118 100644 --- a/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/roi_pooling_infer.h" +#include "nnacl_c/infer/roi_pooling_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc index 25ca4491f7c9401bf0e75438434cf8789bebdf76..5d3813c8f5db33e809a88feb7876feed7bdb6d40 100644 --- a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc @@ -15,8 +15,8 @@ */ #include "common/common_test.h" -#include "nnacl/scatter_nd_parameter.h" -#include "nnacl/infer/scatter_nd_update_infer.h" +#include "nnacl_c/scatter_nd_parameter.h" +#include "nnacl_c/infer/scatter_nd_update_infer.h" namespace mindspore { class TestScatterNdAddInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc index 70efa6d31882070333a0f31fa0401e0606ce6053..3600d37ddb2746d74226bc6ead5943aa9686a4a0 100644 --- a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/scatter_nd_infer.h" +#include "nnacl_c/infer/scatter_nd_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc index f6d308f97726262819120cf1ecf1f7c05fc39265..67f74a02d096e5eee3c8a6036f7d54e236282d87 100644 --- a/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/select_infer.h" +#include "nnacl_c/infer/select_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc index 72c60a45cb316bd2bb0102710845664c1a8a888e..340bc8bfa7c1493c3030dbba895b73dbf0fd2f71 100644 --- a/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/sgd_infer.h" +#include "nnacl_c/infer/sgd_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc index bbc10b3a5958747d578c96338b0119205aa2c898..920f589daa967c307d5360cf4e6bdb9d61e8f85e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/shape_infer.h" +#include "nnacl_c/infer/shape_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc index f6e9b7d44eecd303e31c5318a39bbbc1c42a28c7..2325f9f8c7b0f5c2287fa17dfd71f4860881e8ec 100644 --- a/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/size_infer.h" +#include "nnacl_c/infer/size_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc index ef7adebb89820cab398f04a2e14d937acf62ab09..9669e06043e0e54e3d2550d22f6a5e7a0b781832 100644 --- a/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/skip_gram_infer.h" +#include "nnacl_c/infer/string/skip_gram_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc index 33fdcd74809f239ccb6f60664db03627dac75fce..087afc9ccf45cb889b21d1beb843bebd9ec58e45 100644 --- a/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/slice_infer.h" +#include "nnacl_c/infer/slice_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc index 27a7c96ce88711d6fe4af41e74b3d86bed0694a8..b1cb21708446ce643812906e27ca749f59bb44dd 100644 --- a/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/softmax_cross_entropy_infer.h" +#include "nnacl_c/infer/softmax_cross_entropy_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc index 835845876a8ea6cc95c6aa8f0abbe24cdf61b011..0b440079e804510b18f0d268dbb27f34e5ad1dce 100644 --- a/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/softmax_infer.h" +#include "nnacl_c/infer/softmax_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc index c780c35453776bcafa35ba1a4b32a857df3598be..4786e258b551d835f93ae6208032e3a238ebfd10 100644 --- a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/space_to_batch_infer.h" +#include "nnacl_c/infer/space_to_batch_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc index 224d3edfab5803c4f876d2e33c95ed977065b071..0a160e43e8ab165376a3ab6364d47c0605187327 100644 --- a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/space_to_batch_nd_infer.h" +#include "nnacl_c/infer/space_to_batch_nd_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc index 4293d921734319578fb40da80b6c712e26f97fda..ea69414ff0952155e952cbb8e27f5b065fbe864c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/space_to_depth_infer.h" +#include "nnacl_c/infer/space_to_depth_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc index a04d0d03c85236b40603f1eadbd412dea955458a..76bf4f9012e0eafd12e75756ee1831dcbd02928f 100644 --- a/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/sparse_to_dense_infer.h" +#include "nnacl_c/infer/sparse_to_dense_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc index 64242b0af7e83843aa5dbadbbebea7ef1ce384e9..f755188d85b3574a4a212e200857cc575c699a22 100644 --- a/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/split_infer.h" +#include "nnacl_c/infer/split_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc index 3873efc66726f673e938debe3d450e9e6fd38a8d..5b18156324266697ec2f9775e1fe0ed7a80b1626 100644 --- a/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/squeeze_infer.h" +#include "nnacl_c/infer/squeeze_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc index 7cddaab674212556d40ce0620be4555de05fb03f..654355384a86aa11f7c0660ef11655b77e7b2e13 100644 --- a/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/stack_infer.h" +#include "nnacl_c/infer/stack_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc index 37fbd4e52aec6d7dd206b76952445ea6e7a1f25a..6a4ab89c74237b0ba7712e5576e4a5177c975bbe 100644 --- a/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/strided_slice_infer.h" +#include "nnacl_c/infer/strided_slice_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc index 481690f8d2097443e003e013a8f50b099ddb8d8a..eafde94158f7d4828b368234b66215e90959b676 100644 --- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" #include "src/common/tensor_util.h" -#include "nnacl/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl_c/infer/control/tensorlist_fromtensor_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc index 4f39c0936700e833adc936707946f6df404b627d..2a235823dc47418dd911b6dbcdb645e25e1a7df5 100644 --- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" #include "src/common/tensor_util.h" -#include "nnacl/infer/control/tensorlist_getitem_infer.h" +#include "nnacl_c/infer/control/tensorlist_getitem_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc index abefbeea3c7841598fb2224afd424044ebae0dbd..fe8262382c94d067295ef7d63554d2e1aa604cab 100644 --- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" #include "src/common/tensor_util.h" -#include "nnacl/infer/control/tensorlist_reserve_infer.h" +#include "nnacl_c/infer/control/tensorlist_reserve_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc index f726e159bc84b6d7622356577101c84302a70110..8102692683c12430da207c951809cb74c1206b3b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" #include "src/common/tensor_util.h" -#include "nnacl/infer/control/tensorlist_setitem_infer.h" +#include "nnacl_c/infer/control/tensorlist_setitem_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc index 42a7adbf68902ccb51af3bea364de0c073b40bad..8095670e70f7be58bd336c3ce343d236d847d244 100644 --- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/control/tensorlist_stack_infer.h" +#include "nnacl_c/infer/control/tensorlist_stack_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc index 705c127e3d793e68694203d035a057c56f725962..316c9f3fffc6a3ff49c14b7345dc7a0e9cbaeb26 100644 --- a/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc @@ -14,9 +14,9 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/tile_infer.h" -#include "nnacl/base/tile_base.h" -#include "nnacl/tile_parameter.h" +#include "nnacl_c/infer/tile_infer.h" +#include "nnacl_c/base/tile_base.h" +#include "nnacl_c/tile_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc index d467ae27b6897de047b81c1682fa01b5c5c493f3..55e98524a89e12707a92955b64c6dcf253fdcd0e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/topk_infer.h" +#include "nnacl_c/infer/topk_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc index 5d300eb88e92492f38b6bb8f485fb9b054efb474..4096d1cd50e7c093483da576468ab8a02eefd2b1 100644 --- a/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/transpose_infer.h" +#include "nnacl_c/infer/transpose_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc index 0facfd6c480dd8ee4e50d34ab3e2328bc71f808e..534921c91eff093c341d1e6f826b4d77ad1c39ec 100644 --- a/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/unique_infer.h" +#include "nnacl_c/infer/unique_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc index 9094cc25bd1bb7a2d05b4cc99289fa35458959f6..4f14be0f6e7476a983ab30df97db7ad1c26f76f2 100644 --- a/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/unsorted_segment_sum_infer.h" +#include "nnacl_c/infer/unsorted_segment_sum_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc index 98ad4b3fb638be0a664495b4033bdd3ea70a1774..bcf3bd196dbd46663fefd9e038db07e8318ebe88 100644 --- a/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/unsqueeze_infer.h" -#include "nnacl/unsqueeze_parameter.h" +#include "nnacl_c/infer/unsqueeze_infer.h" +#include "nnacl_c/unsqueeze_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc index ee31cae61623fbb91112d18feaaddb55b0023a09..c5bcb70c193af53c90b52a787b7eb32ae7c75c2d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/unstack_infer.h" +#include "nnacl_c/infer/unstack_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc index 39bfe729c70b23f858bfb468506b09ed3dc9e9ad..f24572b8f7dff89e4371d02da73fbf628a3177da 100644 --- a/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/where_infer.h" +#include "nnacl_c/infer/where_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc b/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc index d8df2ff1949934ccbd587cead7072866e94773ee..18c2074e20e31bc2ed7a6b3aa82f43b511eafc00 100644 --- a/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc +++ b/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc @@ -16,9 +16,9 @@ #include #include #include "common/common_test.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #ifdef ENABLE_ARM64 -#include "nnacl/fp16/quant_dtype_cast_fp16.h" +#include "nnacl_c/fp16/quant_dtype_cast_fp16.h" #endif namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc b/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc index 63f95d22584b4f206db4de846fa7ca33185f1bc8..ef8db4b3627561d77161f40527aebd419e65dfa0 100644 --- a/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc +++ b/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc @@ -16,9 +16,9 @@ #include #include #include "common/common_test.h" -#include "nnacl/op_base.h" -#include "nnacl/base/cast_base.h" -#include "nnacl/kernel/cast.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/cast_base.h" +#include "nnacl_c/kernel/cast.h" namespace mindspore { class CastTest : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc index 5642f8351936bf7f5e636c5680deb6d699aed0bf..9c2fd933f47aed464d9eeb328d19201e18dff56a 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc @@ -19,11 +19,11 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/pack.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/pack.h" +#include "nnacl_c/fp32/conv_common_fp32.h" #ifdef ENABLE_FP16 -#include "nnacl/fp16/pack_fp16.h" -#include "nnacl/fp16/conv_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/conv_fp16.h" #endif namespace mindspore { diff --git 
a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc index 971e0c4ddf3f9e96e1eca5aad0dd7d54390fd208..c51aa18c81441e0476e441e0e82f57a3f5496e57 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" #include "nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc index 2b671b7cc1b49d7cbd0593cc37f44136aff39ade..75f20b1694f6b6eea05bd45a3e359c360a14a31c 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc @@ -21,7 +21,7 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/fp16_grad/activation_grad_fp16.h" +#include "nnacl_c/fp16_grad/activation_grad_fp16.h" namespace mindspore { class TestActGradFp16 : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc index add5c8ec6c10b1fbe0d4807379ea9deb7d966de1..3d5a767d90b03a0c432f4b7934a294a7ba7fe899 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc @@ -21,7 +21,7 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/fp16_grad/arithmetic_self_grad.h" +#include "nnacl_c/fp16_grad/arithmetic_self_grad.h" namespace mindspore { class TestArithmeticSelfGradFp16 : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc index 6344e560b4e7df671ea03ad404da6add0351022a..a19d152cb72c8e174f767129d0dfbe08c1719424 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc @@ -19,7 +19,7 @@ #include #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/tensor.h" #include "include/securec.h" #include "src/litert/infer_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc index 9f7dc50fc11006eb3ce5a4012f7e68ce48703ef2..2cab22921e6fb05388fa8c81c9b356be490bf797 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc @@ -16,7 +16,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "src/executor/kernel_exec.h" #include 
"nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc index edb2dde96fd62b3434aa8b0f3d92bb6466941337..a89a7338e5b319e3c91cb3044058d6a0a8fc6e09 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc @@ -15,9 +15,9 @@ */ #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/base/batch_to_space_base.h" -#include "nnacl/batch_to_space_parameter.h" -#include "nnacl/common_func.h" +#include "nnacl_c/base/batch_to_space_base.h" +#include "nnacl_c/batch_to_space_parameter.h" +#include "nnacl_c/common_func.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc index 03995b9c407914244945d47352d43bc1ad70c43a..2943341be97ec5de54530cd0a7e641fc343d3345 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" #include "nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index 019d5ffa52a739569d07435465988dfb3b273a68..6c6420dfbcc040feda35e4f1268c1a999635ad72 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -15,7 +15,7 @@ */ #include #include "common/common_test.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc index 7f28e655a3aa5d66346753dee317b13d9a63e1ef..83d510c4d30d944f45af893e68c4bfb80aeea17a 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/fp32/crop_fp32.h" +#include "nnacl_c/fp32/crop_fp32.h" #include "src/litert/tensor_category.h" #include "src/litert/lite_kernel.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc index 2e483c0a3347e2ae75a1310b91383b4d254010bf..0fa934716a14184c2fe2e54680345e72e4c5badb 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/cumsum_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc index 39dcc6ef170e3c0149196188e8f34973100b04f0..16a9c5983dfc6e4b145f96419b1461ec2277704b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -18,8 +18,8 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/fp32/deconv_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/deconv_fp32.h" +#include "nnacl_c/op_base.h" #include "src/litert/tensor_category.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc index 08497018f904f203f536a34fd85b73099c776ea3..901dcd0626b143223920ab1b14f225509336c670 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc @@ -15,10 +15,10 @@ */ #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/base/depth_to_space_base.h" -#include "nnacl/common_func.h" -#include "nnacl/depth_to_space_parameter.h" -#include "nnacl/kernel/depth_to_space.h" +#include "nnacl_c/base/depth_to_space_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/depth_to_space_parameter.h" +#include "nnacl_c/kernel/depth_to_space.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc index f665feb94fd6363d8fcb71d917fe286a26898236..92fc83cadc0b1758eafc3742f82f3e12d9756d50 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc @@ -16,7 +16,7 @@ #include #include "src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h" -#include "nnacl/fp32/embedding_lookup_fp32.h" +#include "nnacl_c/fp32/embedding_lookup_fp32.h" #include "src/common/file_utils.h" #include "common/common_test.h" #include "src/common/log_adapter.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc index 29d0ed36c889548049367551d31be605000de324..1c9308f566f8dee8858c6d5f5e05bd164851b439 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc +++ 
b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/common/file_utils.h" #include "src/litert/tensor_category.h" #include "src/common/log_adapter.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc index 68cfdbc3dd5e0b2d504265e6df1bb68a49031019..58cbcc248a729170f046386a417747641c759fac 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc @@ -16,7 +16,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc index 687562f0f4a78a14c6c8d9da681d773dbec709d5..deca358511ad4826d7dc14dbbf3ea0e56ec8fca4 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/lsh_projection_parameter.h" +#include "nnacl_c/lsh_projection_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc index 9972cf5eb1282ef952a961d9a116c94030ca83c3..9441e03ba976122d9f4b2349bc4ce6e661c4f6a5 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc index 5f20562575e4d416a92d2fd9fbc9c3bf4337c6e6..b7dcc7121fd743be5ffcf85e6527e98fa2704a73 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc @@ -16,8 +16,8 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/litert/tensor_category.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc index 3c3b47c8c3e16ed8941235bcbbebd348e893b2af..6fdbba8fd41fa5cebbb17166b7a2777e7535fc7d 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc +++ 
b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #include "src/executor/kernel_exec.h" #include "src/litert/tensor_category.h" -#include "nnacl/nllloss_parameter.h" +#include "nnacl_c/nllloss_parameter.h" #include "nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc index c23773feac1c864c1c5190ea722860cf83c28232..cb9ec5c8f93ec3e18568018457ef5ac7094d53ff 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc @@ -17,7 +17,7 @@ #include "src/executor/kernel_exec.h" #include "src/tensor.h" #include "common/common_test.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" #include "src/litert/kernel_registry.h" #include "schema/ops_generated.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc index 8d936215781cea184b62bbf75d6e6ef44be40c0c..7a942a1424ced78041a78cfbc9ec2c2f36d95ca2 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc @@ -19,7 +19,7 @@ #include "src/executor/kernel_exec.h" #include "src/litert/tensor_category.h" #include "nnacl/nnacl_manager.h" -#include "nnacl/pow_parameter.h" +#include "nnacl_c/pow_parameter.h" namespace mindspore { class TestPowerFp32 : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc index aac8dd24f3b8d018be2fe23eb1df31e71814dd73..2d5360759c79273672a7825b91c0d5de7ac63963 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/fp32/ragged_range_fp32.h" +#include "nnacl_c/fp32/ragged_range_fp32.h" #include "src/tensor.h" #include "src/executor/kernel_exec.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc index 7db8176cf68fb7dd614b3438cf2c47e5794deb35..ca6857e45ae5ffeab9589cd4a9fc1bf0383ba229 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/reduce_fp32.h" +#include "nnacl_c/fp32/reduce_fp32.h" #include "schema/inner/model_generated.h" #include "src/tensor.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc index 7edc4e2d2e337ad2a8aa0edc73529eea367e7d8d..0250d9fb3cab053468812911407742f515816bd2 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc @@ -18,7 +18,7 
@@ #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" #include "schema/ops_generated.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc index 01acf479c24abe2077db2905d35be33c1a951e5d..800a586e278f76304e5593e61e468de5995d4dbf 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc @@ -15,7 +15,7 @@ */ #include #include "common/common_test.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc index 4f738dc9d0e59759439e1b31e78718047e808dd9..0a1f041a1e56fcbf934d7baa47e2f8211a039900 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/fp32/reverse_sequence_fp32.h" +#include "nnacl_c/fp32/reverse_sequence_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc index 785ca827fbc1cf4de1d3a6eafa2fc9ee187b0716..5b16bb68fbcb5b5585b64e896e5b23c35c06de7c 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc @@ -17,10 +17,10 @@ #include "src/executor/kernel_exec.h" #include "src/tensor.h" #include "common/common_test.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" #include "schema/ops_generated.h" -#include "nnacl/fp32/scale_fp32.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/fp32/scale_fp32.h" +#include "nnacl_c/scale_parameter.h" #include "nnacl/nnacl_manager.h" using mindspore::schema::ActivationType; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc index ef8a7b118809592c977941a9a5d67fe0cf0609de..47cf6f3e9f2b6d5b83b9a847ff6ae6c6f95c31c4 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc @@ -16,7 +16,7 @@ #include "common/common_test.h" #include "src/litert/kernel_registry.h" -#include "nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" namespace mindspore { using mindspore::lite::Tensor; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc index 117950c1d8a56ee5641cb3ccce04d902be61e9c2..b171b5b59286e1a07cbafc7b077eeb0550ed85b2 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc @@ -15,7 +15,7 @@ */ 
#include "common/common_test.h" #include "src/litert/kernel_registry.h" -#include "nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" namespace mindspore { using mindspore::lite::Tensor; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc index 0f079ef5f8f23143076840a9a0e86f26da327788..aaa12c6e1077438e2e067351bfad825325556be8 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc @@ -16,7 +16,7 @@ #include #include "src/litert/kernel/cpu/string/skip_gram.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" #include "src/common/file_utils.h" #include "src/litert/tensor_category.h" #include "common/common_test.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc index 84aee370ea26da0dabd716a2d184b3fa088f01a2..57a287d4612d90dec9f604ca44d1a336890e3c73 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc @@ -16,7 +16,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "src/litert/kernel/cpu/nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc index 20f49eab04b20ad4d0c96697c9592e0cdff641c7..88e177b03cd3d4767ca934c9d30e491920a8b1fd 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc index 4ee7d2558c228e7bbaae9f13e0526332b01da200..1ebbe4f872c264079e9858708fc9a17ad0f4742f 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc @@ -18,8 +18,8 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/space_to_depth_parameter.h" -#include "nnacl/base/space_to_depth_base.h" +#include "nnacl_c/space_to_depth_parameter.h" +#include "nnacl_c/base/space_to_depth_base.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc index d390a003de8bb1ad88ea631d6891491bfcdc9d84..08704f9963a5ca0f677dde129ebdd299e955da45 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include 
"common/common_test.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc index c454945b58c1978d442d5625622aecd2965cf831..758e7f3f5ed0bdedb3e78cadaace409ef584a4f5 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/base/stack_base.h" +#include "nnacl_c/base/stack_base.h" namespace mindspore { class StackTestFp32 : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc index 44e538cff61f0e20e1a6ea91d874575a32a113ee..fde4a2bf0158e61071fcd736b220bb48162a429d 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" -#include "nnacl/tile_parameter.h" +#include "nnacl_c/tile_parameter.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/cpu/nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc index be71e7c888d37ef80e091ba96115c90db58253b6..f6852855a123391c5cbcf07d89fabf6d9100e6b2 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc index 1978d45c7130219fdc2dd430ec031ec67c124a41..87210921803ff0c5c67a3f89367ca0e8d8c76c4a 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc @@ -18,8 +18,8 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/transpose_fp32.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/transpose_parameter.h" #include "nnacl/nnacl_manager.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc index 4a2989bf854dcf439d345661fa8bc667261ab832..be300dea62eb2c95af34d8aa1a00c8b7f883b2f3 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git 
a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc index c5c1849274841a0f082a78a0698f862d90583516..047d16acca669b56fe947479c4d636a8482641f9 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/fp32/unique_fp32.h" +#include "nnacl_c/fp32/unique_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc index 09d61569f26811aec3a586d32112824f907b993e..4b57c33c19abececa87468928f66d0b74589bca5 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/base/unstack_base.h" +#include "nnacl_c/base/unstack_base.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc index fb7c90c1d227a8ffea0f95aab62a69c4c0361c04..1d0ef614605a42d9638011e5dbbe79d64958b495 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc @@ -25,7 +25,7 @@ #include "src/tensor.h" #include "src/executor/kernel_exec.h" #include "src/litert/kernel/cpu/fp32_grad/activation_grad.h" -#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl_c/fp32_grad/activation_grad_fp32.h" namespace mindspore { class TestActGradFp32 : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc index 144dd5c1d9bf6974bdc11dadacba458f27ff1844..72265769f25a52f645f92acaae0454f9f7f109b7 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc @@ -21,7 +21,7 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/fp32/reduce_fp32.h" +#include "nnacl_c/fp32/reduce_fp32.h" #include "src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc index 16d44cfa10c135cfffb61b843b59c8e18b4a65a2..642f5e204be9e29a227f86441fdfac2f39181435 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc @@ -19,10 +19,10 @@ #include "common/common_test.h" #include "src/common/file_utils.h" #include "src/litert/kernel/cpu/fp32_grad/bn_grad.h" -#include "nnacl/fp32_grad/batch_norm_grad.h" -#include "nnacl/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32_grad/batch_norm_grad.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" #include 
"src/litert/kernel_registry.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" namespace mindspore { constexpr int kSize3 = 3; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc index dbc4dd0ecb7c019165a5a32c3f5ea796e85121a7..669032a9fe4243d6f44a0f2bbe44a907cb3f17b1 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc @@ -23,7 +23,7 @@ #include "src/litert/kernel/cpu/fp32_grad/convolution.h" #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.h" #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc index f5446f121ca6c98ef19d2189f3c3ed50f6265aec..cd1ae7e598321172b77a407ea424c0204602a9bf 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc @@ -20,7 +20,7 @@ #include "common/common_test.h" #include "src/common/file_utils.h" #include "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index c6eda84e517b99148b74526483985a5881573a3a..473caf06d37d6ba3d264693883dee97f7c180c0c 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -20,8 +20,8 @@ #include "common/common_test.h" #include "src/common/utils.h" #include "src/common/file_utils.h" -#include "nnacl/fp32_grad/pooling_grad.h" -#include "nnacl/kernel/pooling.h" +#include "nnacl_c/fp32_grad/pooling_grad.h" +#include "nnacl_c/kernel/pooling.h" #include "src/litert/kernel/cpu/fp32_grad/pooling_grad.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc index 5e0be9ef77473f379429a579cbcccc0cb76004fa..b04e75c25b47bbf932dad12db4468a1282e16337 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc @@ -22,7 +22,7 @@ #include "src/common/utils.h" #include "src/common/file_utils.h" #include "src/litert/kernel/cpu/fp32_grad/softmax_grad.h" -#include "nnacl/fp32_grad/softmax_grad.h" +#include "nnacl_c/fp32_grad/softmax_grad.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc index 
f4463fce2b91b77310d1d9d44acfab709e6c39f3..8eb3c104a199a1a51c05b102c752932da10612da 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc @@ -17,7 +17,7 @@ #include #include "schema/inner/model_generated.h" #include "common/common_test.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/arithmetic_self_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc index eb560385f89cd7762f874301d32f5f848bc58fc5..72abe37c287ab054fcd53e8be144a2817fe16a34 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc @@ -17,8 +17,8 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/batchnorm_parameter.h" -#include "nnacl/int8/batchnorm_int8.h" +#include "nnacl_c/batchnorm_parameter.h" +#include "nnacl_c/int8/batchnorm_int8.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc index 52d09191abdf381db02c50bef5b692e4d216bd8c..ca25ffd238f52634fa7e08e70f1be689f5cf1279 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc index 6715924847c52734667602fb9fa3cf0335f2e743..0ba9e2a7465fccce9bf44f0e7fbc3a57e9c6815b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc @@ -17,8 +17,8 @@ #include "common/common_test.h" #include "src/executor/kernel_exec.h" #include "src/common/file_utils.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/common_func.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/common_func.h" #include "src/litert/kernel/cpu/int8/convolution_1x1_int8.h" #include "src/litert/tensor_category.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc index a7f57980cfaec5c4cc3e8b076d82bb10b5e3de61..a301fba65257a42817c4395c0bbddfda6ffc7ceb 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/crop_parameter.h" +#include "nnacl_c/crop_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" 
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
index fe80324f738612cb889b51a8cd08989071452fb0..80c3d256ab288bcf959dffbf1a1e481438599389 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
@@ -20,9 +20,9 @@
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32/matmul_fp32.h"
-#include "nnacl/int8/deconv_int8.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/int8/deconv_int8.h"
 #include "src/litert/kernel/cpu/int8/deconvolution_int8.h"

 using mindspore::lite::DeviceType;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
index 64bf27686dd74f0e60ee4e2700e11fce9a7d2b39..21e9c732ed518e29140480de57c4f18472ecfd2e 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
@@ -17,8 +17,8 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/int8/fullconnection_int8.h"
-#include "nnacl/common_func.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
index fc5376dcddfeb716dc87c1816ccff92f8e67b13f..c3bbcf72f266741b9a6427778c46ee6a08ec4cf0 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
@@ -16,9 +16,9 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/gatherNd_fp32.h"
-#include "nnacl/int8/gatherNd_int8.h"
-#include "nnacl/gather_nd_parameter.h"
+#include "nnacl_c/fp32/gatherNd_fp32.h"
+#include "nnacl_c/int8/gatherNd_int8.h"
+#include "nnacl_c/gather_nd_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc
index f136ea84b2e21cf59793c68ff196e8f9640e18a8..f21710645f87d7b9e2858676967c0759f91a9ac8 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc
@@ -16,8 +16,8 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/int8/gather_int8.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/int8/gather_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc
index cd3a3f9b25ec53d5a5a6216d64cc4e244ae37de2..ef03b3f52b029a9a51bdec3d0853f9c2440d89e3 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc
@@ -18,7 +18,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "src/litert/kernel/cpu/int8/hswish_int8.h"
 #include "src/litert/kernel_registry.h"

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc
index 9109f9582db0538e35a95e5fc4fc3ee278ab383c..b8ea47a507b0198c94530867f490c35b543fb7f7 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/l2_norm_parameter.h"
+#include "nnacl_c/l2_norm_parameter.h"

 namespace mindspore {
 class TestL2NormInt8 : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
index 588f11fd65ce90ee7cafb7d1c350ec9822dcad28..fde391e557101e2175962491fe9faa43779521fd 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
@@ -18,9 +18,9 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/int8/matmul_int8.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/common_func.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/matmul_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc
index 2da74fa23c48342da4617dc6db82fa5f4abb59f3..e42c11226545891d6d749a519b088767242b605a 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc
@@ -18,11 +18,11 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/mul_parameter.h"
+#include "nnacl_c/mul_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"

 namespace mindspore {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc
index d8cbc4bbc35d35fba933f173ea73914cf639097a..1d94b775e17672837f33717731226e6d30a3310b 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc
@@ -20,7 +20,7 @@
 #include "src/litert/tensor_category.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "nnacl/pad_parameter.h"
+#include "nnacl_c/pad_parameter.h"
 #include "src/litert/kernel/cpu/int8/pad_int8.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc
index 039e37efc84455c5c9ceb9c51f197012cac0508d..3899c22acd346bf23a7282f673bd1b5978437bad 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc
@@ -19,7 +19,7 @@
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/int8/power_int8.h"
-#include "nnacl/pow_parameter.h"
+#include "nnacl_c/pow_parameter.h"
 #include "src/litert/kernel_registry.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc
index e87b38594017adff77eefa9db65ada2c9c39c5b1..0e55eb45c22d911dbceaf6149b9301c7e1cf2b4e 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc
index d362571a3ce0e570e24248b661bfb56e1d9f755a..576225b6ba9171e64f24dd4b77766af7439aeb68 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc
@@ -19,7 +19,7 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/base/quant_dtype_cast.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc
index 88099503fa8e86f4e419bdefb956e9795a791100..c9e038eb26df1354c32fbf9162e0215eff54d484 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc
@@ -19,7 +19,7 @@
 #include "common/common_test.h"
 #include "src/tensor.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/reduce_fp32.h"
+#include "nnacl_c/fp32/reduce_fp32.h"

 namespace mindspore {
 using mindspore::lite::LiteQuantParam;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc
index 794f0e056bd43d2cc72fb7b579b80f10948dda52..a765e82f2a1f466cf760746e97db34489600d238 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/reshape_parameter.h"
+#include "nnacl_c/reshape_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc
index 9777c6b613627acdb1239a121f3f7a3f4939817f..c79ef65b3e6f734e9a463bbe26787d36098e1bc4 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc
@@ -19,7 +19,7 @@
 #include "src/tensor.h"
 #include "common/common_test.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/resize_int8.h"
+#include "nnacl_c/int8/resize_int8.h"

 namespace mindspore {
 using mindspore::lite::LiteQuantParam;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc
index d1b7985d354a938c49e92adfee4f3b8dcb41ab67..c664a889d520d16959f132c7c471326c37dd8f74 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc
@@ -19,7 +19,7 @@
 #include "src/tensor.h"
 #include "common/common_test.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/resize_int8.h"
+#include "nnacl_c/int8/resize_int8.h"

 namespace mindspore {
 using mindspore::lite::LiteQuantParam;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc
index 4d16b9a9988dd7fef8326f7bc7187910c75e0db6..2df0fbc276f5c6fe6143ff8ee4d82f73add4c166 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc
@@ -19,7 +19,7 @@
 #include "common/common_test.h"
 #include "src/tensor.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/scale_int8.h"
+#include "nnacl_c/int8/scale_int8.h"

 namespace mindspore {
 using mindspore::lite::LiteQuantParam;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc
index 251c68d66d0f783b328aaf3f256191383aa6025b..148ea7b537bca64506d38aa5094d5062ae5e14ac 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "src/litert/kernel_registry.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc
index cd58822bcc3cfe1920033fc26630b626538e4bbd..3893a239b480736f19eca42c43a87857213ab623 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc
@@ -19,7 +19,7 @@
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/int8/softmax_int8.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 #include "src/litert/kernel_registry.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc
index 4a7a0e8618298e4321f5f254b690e66bc622a1db..d3ad1f6b58cb303ecedf465ffecbe9aff227cc25 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc
@@ -15,7 +15,7 @@
 */
 #include
 #include "common/common_test.h"
-#include "nnacl/fp32/space_to_batch_fp32.h"
+#include "nnacl_c/fp32/space_to_batch_fp32.h"
 #include "src/litert/kernel_registry.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc
index f528a8a8198501e54018b3b10d8101747e5520a3..7fe33d32137b8110fa6f56f3ea986cf6040428bc 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/split_parameter.h"
+#include "nnacl_c/split_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc
index 46b8c22a522acf143877e69d9cc9fdbcfdd9336c..51b6102b3a2f39a26c6e46bef2ab6f14fecf1f03 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/squeeze_parameter.h"
+#include "nnacl_c/squeeze_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc
index 2e6051f2c8b1829f925a918bff066dabe2922c0f..173e01368c957d10d0e72fab11ea5cbe43812d64 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc
@@ -18,7 +18,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/topk_fp32.h"
+#include "nnacl_c/fp32/topk_fp32.h"
 #include "src/litert/kernel_registry.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc
index a31ba896474322685ae5714d81d48b66cbd059d4..d509b258bd3aa3f7ad5c7170a422a3666b2b55b4 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "nnacl/unsqueeze_parameter.h"
+#include "nnacl_c/unsqueeze_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc
index c79d737b19c666412cfe84f8a644d645ec362c3f..6644e069fb8c8c388e52303daf2643f2ea559929 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc
@@ -18,7 +18,7 @@
 #include
 #include "src/litert/kernel/cpu/string/skip_gram.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/skip_gram_parameter.h"
+#include "nnacl_c/skip_gram_parameter.h"
 #include "src/common/file_utils.h"
 #include "common/common_test.h"
 #include "src/common/log_adapter.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc
index 2b176a0f791389fa1e3aba8a29015c0ba95ac012..a3d61223ab50ed844b9f8030f9cb0b4865e9368d 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/ops_generated.h"
 #include "src/extendrt/kernel/cuda/batchtospace.h"
 #include "ut/src/extendrt/kernel/cuda/common.h"
-#include "nnacl/batch_to_space_parameter.h"
+#include "nnacl_c/batch_to_space_parameter.h"

 namespace mindspore {
 class CudaTest_BatchToSpace : public CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
index bb53b44e4317c73f25b5219ec0358bedda2da1b1..bb591e81fb42e0a1bea190e5b76b0737374a7271 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
index 0e614fbbb3ce2eab4bc28cad9b9fb94bf4caef94..7f2b318013868708d51036723a470bd5d0cb5784 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/arg_min_max_parameter.h"
+#include "nnacl_c/arg_min_max_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc
index 99aaf1f6267a88a955043ab58567a6c31c1a062c..6ebfd6e190a2461e60fb53ad9ca7ae9c9461f591 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/arithmetic_self_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
index a5c56665766464af2cb3bb1d498a6379266a6b1d..19633ff7c5321b674e5098e63423a96ebf0ea768 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc
index d55c78564f1769ead59db1e1fe7dd426f2ba5c1c..76b0edf2f25a4b32465091da7aaf7cf44a992dd4 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc
@@ -15,7 +15,7 @@
 */
 #include
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/batch_to_space_parameter.h"
+#include "nnacl_c/batch_to_space_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
index 6121eff5d00b01761e9d18451e501335d95a47b3..ad411f94d580bc8dd80184f6ebb6b65386d11cf9 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc
index 2305511f012921726f6283d94a99b41bbab4fedd..a7783bf9bfefb58e771e821a00e375d202e547cd 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc
@@ -18,7 +18,7 @@
 #include
 #include "src/litert/kernel_registry.h"
 #include "src/litert/kernel/opencl/opencl_subgraph.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"

 using mindspore::kernel::KernelExec;
 using mindspore::kernel::OpenCLSubGraph;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h
index a0b5be1e09d7e37900a72d1c8b02a69f52572b9a..81318af6ad95a05dc1ee96bb3ebc933f7a67b26b 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h
@@ -23,7 +23,7 @@
 #include
 #include
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "ir/dtype/type_id.h"
 #include "src/tensor.h"
 #include "src/litert/tensor_category.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
index 37536eb38adc97a3ea86cc0e724126f9ae172238..c1aef2226975a39d3c9dd5d7495a8120f0a56644 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc
index e45ec80636dd0552ad663d3431681f1dfbd35d0b..57ec7dffb820776c182dbd6963ede5662be5e5ce 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
index e143c1ee8dc1d7d04972586f099cae19bd004f6f..a7a5c405c9d62f34786408027c070e1321c7f3b7 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc
index bc4e9b78d02486f7e51a2b5fa58b5c75990544af..14569e4d86f481e2d76c2fb54f27c9bc00abe6d0 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc
@@ -15,7 +15,7 @@
 */
 #include
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/crop_parameter.h"
+#include "nnacl_c/crop_parameter.h"

 namespace mindspore::lite::opencl::test {
 class TestOpenCL_Crop : public CommonTest {};
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
index a0988660ef20244dba3bed0ea59dbf7a629e55bd..42573cfc5e128f5956aa7add37c5886a640fcc58 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc
index 1a4aebe5780c121e1f735de66e3ddd7c507bd2e0..4ca48a2e88453c1c6459e8a06b2f107d037b47d6 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc
index d9c6a487fe76406d1289361750424b3710a8774b..12cf982ecab48e46e859af26c74d1fc5458126ea 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc
index 6b8e23b2976b687ca527bae0b7b158fe5bb6cafb..c1b95dbd00b746d6babc4d3ece7e86e39b8720c4 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/layer_norm_parameter.h"
+#include "nnacl_c/layer_norm_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
index 24a773c8ba42e264ba9a52bb8350df5718eb8c87..a84b161f3de14728a34ebc66da5c34852e156bcf 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc
index d741989a2844d93a1237b20a076aca20540ce703..d4c1659b85042b08f575b83f561b0bec4cf22ace 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/fp32/one_hot_fp32.h"
+#include "nnacl_c/fp32/one_hot_fp32.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc
index e376e9220b158d07fde67e7779c18c7bf449ce08..07fb93b3a9e82797b9294daf107d3be357773005 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/pad_parameter.h"
+#include "nnacl_c/pad_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc
index 731c3a4e19a2f59950e5d9dd16eb90c370fc4d67..6975db168df58beceb2f2356945c2a572334f33c 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/pooling_parameter.h"
+#include "nnacl_c/pooling_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc
index 8915cff76013f200e1f6105305f464a0b522bb64..688ece36b75ab845995e854732a8d214ceaeddf3 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/prelu_parameter.h"
+#include "nnacl_c/prelu_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc
index 538f9ce19c7faa3afc9d0e8de50739c2fb0485fe..67f0924afdd48e8c1bb43bca1ed82f87cc7ad032 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/reduce_parameter.h"
+#include "nnacl_c/reduce_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc
index 38bf78c03d1bc0f5bafaad249be03b7d68b51979..c623f643b8cefaa59eeba0f5e4ee83d1fb4dd9c6 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/reshape_parameter.h"
+#include "nnacl_c/reshape_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc
index e849da46eeabb606154cae257f778c63ca0e8f53..ddfccc45fb5e4615818c39a73e2405154411cafd 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc
index c7423032f9c93964a579207df3d2b5d2708f3512..bf2b99ca97b3c7d1029ffd7bb42078d485a9a616 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/scale_parameter.h"
+#include "nnacl_c/scale_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc
index 0dc26b3defbce28aa3ed2ca8417c58930d11e216..fde49b6fe188c217f76d3e848e4ae36f382f62f9 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "nnacl/slice_parameter.h"
+#include "nnacl_c/slice_parameter.h"
 #include "ut/src/runtime/kernel/opencl/common.h"

 namespace mindspore::lite::opencl::test {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc
index 420e341c132f344b6b6558fb089034ea9efde7e6..b05bb9b4037ee4a7382686de2b7e3d414dca9fab 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc
index 91da782331a97b24a971704b78bdd0d4321a295f..8784b57eb4124763ed10a2971c0be7d74c06397c 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/fp32/space_to_batch_fp32.h"
+#include "nnacl_c/fp32/space_to_batch_fp32.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc
index 3b6f31c32e9901cf40ca55815ce3c414551094ee..0265de503fd2ac53fac987e6fd34632bd82d1921 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc
@@ -14,9 +14,9 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/space_to_depth_parameter.h"
-#include "nnacl/base/space_to_depth_base.h"
-#include "nnacl/depth_to_space_parameter.h"
+#include "nnacl_c/space_to_depth_parameter.h"
+#include "nnacl_c/base/space_to_depth_base.h"
+#include "nnacl_c/depth_to_space_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc
index 1aa86bcbde8407562b341b68f4a71e5255bbd056..344f087621bcc249bf8d6ebbd1f15a95f8bd0019 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc
@@ -5,7 +5,7 @@
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
- * http://www.apache.org/licenses/LICENSE-2
+ * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/sparse_to_dense_parameter.h"
+#include "nnacl_c/sparse_to_dense_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc
index 4e1ccbe2c203d92916435d00bb914da4f44d926d..ce1e3ef161dfb06f074c952307db19732d7508a8 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/split_parameter.h"
+#include "nnacl_c/split_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc
index 8fa436584b91c60cb8f63a8f5124f17ac3e579b3..1a2ed573e312e0f1d0389b3fdd36b3f2a9d255f3 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/stack_parameter.h"
+#include "nnacl_c/stack_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc
index bd6b3d48d1b5afd0fa58384ae59f4823568700dd..a80a328d06d9e019cdac06064a3f7712c2abb221 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/strided_slice_parameter.h"
+#include "nnacl_c/strided_slice_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
index 12c147028c599291dc9309fd3d3393a5f680162a..4a39996101f3b93ba07f0dfe1e30b8e52f6839cf 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "ut/src/runtime/kernel/opencl/common.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/transpose_parameter.h"

 namespace mindspore::lite::opencl::test {

diff --git a/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc b/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc
index e8eb57e963e8e5fddaae739b0c12bc216a68ff87..ea1a53509201205606d9405159fb701f0f5c3f6d 100644
--- a/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc
@@ -17,10 +17,10 @@
 #include "src/executor/kernel_exec.h"
 #include "src/litert/kernel_registry.h"
 #include "src/litert/runtime_pass.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/instance_norm_parameter.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/instance_norm_parameter.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/transpose_parameter.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc
index d5d7a39968f9a5e7613533117cae9e1706a39add..de4ee5bd2b0095952e8cc5a03270d3ff548504fa 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc
@@ -18,7 +18,7 @@
 #include
 #include "tools/optimizer/fusion/activation_fusion.h"
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/cxx_api/activation.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc
index 39260b4385e0f96d32e07690ca273b45a1112242..5f304d4901d61c01d63670f762f46c3a074e8fb6 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc
@@ -23,7 +23,7 @@
 #include "include/backend/optimizer/optimizer.h"
 #include "include/backend/optimizer/pass_manager.h"
 #include "tools/optimizer/fusion/add_concat_activation_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/cxx_api/activation.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/add_fusion.h"
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc
index eb9a2eb178c0ab60ba3cc3993e56f576e26607b8..c53b7eaf78066e9125bd1c513a9d608f6f783fe4 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc
@@ -18,7 +18,7 @@
 #include
 #include "tools/optimizer/fusion/conv_activation_fusion.h"
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/cxx_api/activation.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc
index 8725b4928f449516f7c3e6e34aaac297eb3cc3f4..be04c79d792de5c4d9496c60bc3654dd0fa17a18 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc
@@ -18,7 +18,7 @@
 #include
 #include "tools/optimizer/fusion/conv_biasadd_fusion.h"
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"

 namespace mindspore {
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc
index 69b923630414ffdc8d9a665b239cca2e3fb9ed33..ac5766cfa7572f77aa9b79aad384ccc1f43823f8 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc
@@ -20,7 +20,7 @@
 #include "src/common/log_adapter.h"
 #include "ir/func_graph.h"
 #include "infer/cxx_api/conv2d_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 ValueNodePtr ConvFusionInoutTest::CreateConvPrimitiveValue() {
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc
index e97815016a17f16e413bf26396a8de9a4f345235..e3fb942e4ca9e341ca8460ba65c6c8afd08eb9b9 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc
@@ -22,7 +22,7 @@
 #include "infer/cxx_api/conv2d_fusion.h"
 #include "infer/make_tuple.h"
 #include "infer/return.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_adapter.h"
 #include "tools/common/tensor_util.h"
 #include "tools/optimizer/common/gllo_utils.h"
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc
index 1a509bfd7d3eaed9ece16e6edc9607f32be27390..e6ea162de504f70e3ab216e0a4c7cace6e6621f0 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc
@@ -18,7 +18,7 @@
 #include
 #include "tools/optimizer/fusion/matmul_activation_fusion.h"
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/cxx_api/mat_mul_fusion.h"
 #include "infer/cxx_api/activation.h"

diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc
index fa531e8ea536547d3f6ff9a6f73406560763f219..fefb608d9d06a76e0dfce18ebb5822a4119c8dc4 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc
@@ -20,7 +20,7 @@
 #include "src/common/log_adapter.h"
 #include "ir/func_graph.h"
 #include "infer/cxx_api/mat_mul_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 CNodePtr MatMulFusionInoutTest::AddMatMul(const FuncGraphPtr &graph, const AnfNodePtr &input1, const AnfNodePtr &input2,
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h
index d627f5a5a6f8d36d5b1082e35f92d129ca1339c4..578652883232083ede3c1fd1e9508d7a4ec0ff8e 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h
@@ -20,7 +20,7 @@
 #include
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h"
 #include "ir/anf.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "include/backend/optimizer/pass.h"
 #include "include/backend/optimizer/optimizer.h"
 #include "include/backend/optimizer/pass_manager.h"
diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc
index 3efe5807bc8d37cc834bf3d2db3c63a2375c7cc8..5e6ac368d320b84595dc07efa428060fbce7cd3e 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc
@@ -18,7 +18,7 @@
 #include
 #include "tools/optimizer/fusion/matmul_mul_fusion.h"
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/cxx_api/mul_fusion.h"
 #include "infer/cxx_api/mat_mul_fusion.h"

diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc
index 961c8be58b791bc04e1adef8ffff35e9577c5cc9..1d40ce2cf91395d41b52b85e169a0ba1b129ba55 100644
--- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc
+++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc
@@ -19,7 +19,7 @@
 #include "tools/optimizer/fusion/transpose_matmul_fusion.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"

 namespace mindspore {
diff --git a/mindspore-lite/tools/benchmark/CMakeLists.txt b/mindspore-lite/tools/benchmark/CMakeLists.txt
index 0af2d403a5d9b47d04186ef1f6fb047e97a65137..dcb83dd26b8b655b62ba9cc22180333ca680643a 100644
--- a/mindspore-lite/tools/benchmark/CMakeLists.txt
+++ b/mindspore-lite/tools/benchmark/CMakeLists.txt
@@ -61,7 +61,7 @@ if(MSLITE_EXPORT_COMPUTE_IR)
     set(BENCHMARK_LINK_LIB ${BENCHMARK_LINK_LIB} mindspore_lite_drawer)
 endif()

-include_directories(${OPS_DIR}/kernel/cpu)
+include_directories(${NNACL_DIR}/../)
 set(COMMON_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
@@ -69,7 +69,7 @@ set(COMMON_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/config_file.cc
-    ${OPS_DIR}/kernel/cpu/nnacl/nnacl_common.c
+    ${NNACL_DIR}/nnacl_common.c
 )

 include_directories(${TOP_DIR}/mindspore-lite)
diff --git a/mindspore-lite/tools/benchmark/benchmark_base.h b/mindspore-lite/tools/benchmark/benchmark_base.h
index ab373b25716abf271202c31ba87b44e83851a923..d70491abec58135e678d00e9817d80310dfcff06 100644
--- a/mindspore-lite/tools/benchmark/benchmark_base.h
+++ b/mindspore-lite/tools/benchmark/benchmark_base.h
@@ -41,7 +41,7 @@
 #include "src/common/utils.h"
 #include "ir/dtype/type_id.h"
 #include "schema/model_generated.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore::lite {
 #define BENCHMARK_LOG_ERROR(str) \
diff --git a/mindspore-lite/tools/benchmark/benchmark_unified_api.cc b/mindspore-lite/tools/benchmark/benchmark_unified_api.cc
index 0d902d2df10ae18cbc683344ca8046a6974e0a9e..4b26f6424de03a686e0163dd2fa840ede222eff0 100644
--- a/mindspore-lite/tools/benchmark/benchmark_unified_api.cc
+++ b/mindspore-lite/tools/benchmark/benchmark_unified_api.cc
@@ -26,7 +26,7 @@
 #include "src/common/common.h"
 #include "src/tensor.h"
 #include "tools/common/string_util.h"
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"
 #ifdef ENABLE_ARM64
 #include
 #include
diff --git a/mindspore-lite/tools/benchmark_train/CMakeLists.txt b/mindspore-lite/tools/benchmark_train/CMakeLists.txt
index a915c5812fe088dfa6e91a2e9d311d41facfe080..75ce04079b1690ab832710f4e1b4351adea19316 100644
--- a/mindspore-lite/tools/benchmark_train/CMakeLists.txt
+++ b/mindspore-lite/tools/benchmark_train/CMakeLists.txt
@@ -5,7 +5,7 @@
 set(COMMON_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
 )
-
+include_directories(${NNACL_DIR}/../)
 set(TEST_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/main.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
diff --git a/mindspore-lite/tools/common/func_graph_subgraph.cc b/mindspore-lite/tools/common/func_graph_subgraph.cc
index 76310685be59ca2e902234dbadd07b7a32712628..6abd2edaab7ded7ce8f04889c63e3661cf54427f 100644
--- a/mindspore-lite/tools/common/func_graph_subgraph.cc
+++ b/mindspore-lite/tools/common/func_graph_subgraph.cc
@@ -27,7 +27,7 @@
 #include "tools/common/graph_util.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "infer/cxx_api/partial_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"

 namespace mindspore::lite {
diff --git a/mindspore-lite/tools/common/graph_util.cc b/mindspore-lite/tools/common/graph_util.cc
index 43acd95b74e12a112df74c795c34d921aa3817dc..3e407ab8ef21d71fc56b767ee666450171a17713 100644
--- a/mindspore-lite/tools/common/graph_util.cc
+++ b/mindspore-lite/tools/common/graph_util.cc
@@ -28,7 +28,7 @@
#include "tools/common/tensor_util.h" #include "src/common/log_adapter.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/make_tuple.h" #include "tools/converter/converter_context.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/common/graph_util.h b/mindspore-lite/tools/common/graph_util.h index 6d68c92dcf2d881f3b74ec03bb50e3450931c8a8..2d15581c5edde93df1a45d4b427de87de0087a12 100644 --- a/mindspore-lite/tools/common/graph_util.h +++ b/mindspore-lite/tools/common/graph_util.h @@ -35,7 +35,7 @@ #include "src/common/graph_util.h" #include "ir/anf.h" #include "ir/func_graph.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/node_util.h" #include "tools/converter/cxx_api/converter_para.h" diff --git a/mindspore-lite/tools/common/meta_graph_serializer.cc b/mindspore-lite/tools/common/meta_graph_serializer.cc index 7f30ea7f789b086295266a7c8e63d3a6b330ade7..0c0896045e44b5b6ef7ad13c5f8dba4fc7d2817b 100644 --- a/mindspore-lite/tools/common/meta_graph_serializer.cc +++ b/mindspore-lite/tools/common/meta_graph_serializer.cc @@ -21,7 +21,7 @@ #endif #include "flatbuffers/flatbuffers.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ir/dtype/type_id.h" #include "src/common/utils.h" #include "include/errorcode.h" diff --git a/mindspore-lite/tools/common/meta_graph_utils.cc b/mindspore-lite/tools/common/meta_graph_utils.cc index a4378267b9ee05276e0f33c3862eaa9dc7ab1fbd..7456e613c1987c7144b9309597c5904fbf04e259 100644 --- a/mindspore-lite/tools/common/meta_graph_utils.cc +++ b/mindspore-lite/tools/common/meta_graph_utils.cc @@ -19,7 +19,7 @@ #include #include "inner/model_generated.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) { diff --git a/mindspore-lite/tools/common/node_util.cc b/mindspore-lite/tools/common/node_util.cc index c8774ad220dc4aa08340f66ece3464f122a4ba0c..3e26389a55446e1635391d2da4d0cb3196508c45 100644 --- a/mindspore-lite/tools/common/node_util.cc +++ b/mindspore-lite/tools/common/node_util.cc @@ -31,7 +31,7 @@ #include "mindspore/ops/infer/switch.h" #include "mindspore/ops/infer/call.h" #include "mindspore/ops/infer/cxx_api/partial_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/common/opengl_util.h b/mindspore-lite/tools/common/opengl_util.h index 2ff70a81f8b2d0318bbd09894cb8b66b5fa3cc72..9d13fae8954a5646a7bc30a2d3f0b0daa965b67c 100644 --- a/mindspore-lite/tools/common/opengl_util.h +++ b/mindspore-lite/tools/common/opengl_util.h @@ -21,7 +21,7 @@ #include #include #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #if defined(GPU_OPENCL) && defined(__ANDROID__) && defined(ENABLE_ARM64) #include "EGL/egl.h" diff --git a/mindspore-lite/tools/common/statistic_utils.cc b/mindspore-lite/tools/common/statistic_utils.cc index c1ca12841083d1bd7bb8ec7f6590ebc4d76fb478..a6d42560892b5c1e6b77bba77cac72ecb67ec936 100644 --- a/mindspore-lite/tools/common/statistic_utils.cc +++ b/mindspore-lite/tools/common/statistic_utils.cc @@ -16,7 +16,7 @@ #include 
"tools/common/statistic_utils.h" #if defined(ENABLE_AVX) && defined(__linux__) -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #ifdef _MSC_VER #include #else diff --git a/mindspore-lite/tools/common/statistic_utils.h b/mindspore-lite/tools/common/statistic_utils.h index f4d8ab00a6c1a32d85bd0a7a40032404e9aa5c94..93e8cfb7f497a25d0a75cbc58c2734c40d4d3a5a 100644 --- a/mindspore-lite/tools/common/statistic_utils.h +++ b/mindspore-lite/tools/common/statistic_utils.h @@ -25,7 +25,7 @@ #include #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindapi/base/type_id.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/common/tensor_util.cc b/mindspore-lite/tools/common/tensor_util.cc index 319a682dc6b9dc06506ee57ab4a8027f8c55aa5b..a4a30b3f682afb3bf0b387cebfa34448378e7e28 100644 --- a/mindspore-lite/tools/common/tensor_util.cc +++ b/mindspore-lite/tools/common/tensor_util.cc @@ -19,7 +19,7 @@ #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "abstract/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { diff --git a/mindspore-lite/tools/converter/CMakeLists.txt b/mindspore-lite/tools/converter/CMakeLists.txt index bbdf5a362927bdf2d89896d9fcd96a5ae3981524..54b14698a7b2f5046ecf45d87a54843d2d30b869 100644 --- a/mindspore-lite/tools/converter/CMakeLists.txt +++ b/mindspore-lite/tools/converter/CMakeLists.txt @@ -14,7 +14,7 @@ endif() include(${LITE_DIR}/cmake/ccsrc_converter.cmake) -include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu) +include_directories(${NNACL_DIR}/..) file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc diff --git a/mindspore-lite/tools/converter/adapter/acl/common/utils.cc b/mindspore-lite/tools/converter/adapter/acl/common/utils.cc index e20e5097556aeed1cccebb9a35e0729132d71113..4164288a33bf9dfd2f7a16264f92430c7db5a134 100644 --- a/mindspore-lite/tools/converter/adapter/acl/common/utils.cc +++ b/mindspore-lite/tools/converter/adapter/acl/common/utils.cc @@ -25,7 +25,7 @@ #include "include/common/utils/utils.h" #include "src/common/log_util.h" #include "ir/func_graph.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc b/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc index 925123cc5994fd13b93eace9e13ef9b008a0cf5d..6582cfad2351f781e584d3c67f4db67356879235 100644 --- a/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc +++ b/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc @@ -20,7 +20,7 @@ #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc b/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc index f829263bcdbf45f94668854b288fdb6c6edb8905..3b738d7fcf7cb09dc2ae06f5f700582c3f595f19 100644 --- a/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc +++ b/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc @@ -19,7 +19,7 @@ 
#include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace kernel { diff --git a/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc b/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc index 4fbf6df89e7ed7ee7ec04bcc2d226dbb71256e15..b85d4eb42880319b13382b7230c3fc8afe047764 100644 --- a/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc +++ b/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc @@ -19,7 +19,7 @@ #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace kernel { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc index 99c655d6e541997b1fa1e41ea824c51829fdbf37..e49b0abf5c5288ebdedb728db5f840256171fb8a 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc @@ -24,7 +24,7 @@ #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc index 3d53d91210261feb8d2db23bdc22b96afc93579e..de8ec1269cd48b6bd303ff49b8509eef89250994 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc @@ -20,7 +20,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/common/utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_c.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc index 4962713c2184022e4612a655a3d7542d9732e18a..7ae050f428a3f0fb1f7703fdd48d6721f38969cf 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc @@ -24,7 +24,7 @@ #include "ops_utils/op_utils.h" #include "src/common/log_util.h" #include "tools/common/tensor_util.h" -#include "mindspore/ops/kernel/cpu/nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc index 4101bd22680aae2e4053100150c7760436f496b4..b89713802d12f25a9e60604d435364a899f52f58 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc @@ -21,7 +21,7 @@ #include "tools/converter/adapter/acl/common/utils.h" #include 
"include/registry/converter_context.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc index 7b47bf6c13104c20c5c2a5383e3e0178808e1176..2580324cff4e6d86ee58a6b34f62772acfd5d68c 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc @@ -18,7 +18,7 @@ #include #include #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "utils/check_convert_utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc index 72b69e3ea8de89dd2c8e4dd3c94d62d7e9c74a5a..ac97661101eb7a2c3a32bfb233cc3703bbc54904 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_g.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc index 0e7bdfe7c969256774bf0ca1f0f2cceb44bbed9f..6420779ea4171f99b78ca231e665080bbb4760bc 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc @@ -20,7 +20,7 @@ #include #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_g.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc index 4afcd1383405f007b7da79246eb70356d33306bd..831afebb081fe14a2d4156733aa3498e44c5f3e8 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc @@ -21,7 +21,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "tools/converter/adapter/acl/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "src/common/log_util.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc index 6ea31068388c3a707125d745031fc6fcdc1b6959..968a9b0110cf724fda71faf2a6269f886bf32cd6 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc @@ -22,7 +22,7 @@ #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "tools/converter/adapter/acl/common/utils.h" #include 
"infer/lstm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc index c7a9807c447fbd6580f220d67eb2645d6d67b404..04bc0e3c083c209f89d15647bf7e81ec3ff4764e 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc @@ -31,7 +31,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops/base_operator.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc index d7ef4e4029739512fbb717b1b7d414740112ac15..fc2813b49594583d40f68d88bb40e25ee173664e 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc @@ -21,7 +21,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc index ce89b4d6b9919345cbb46e0465b28e2e402af86e..74a0453e111ebeab54606745c829935b9785d6f2 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc @@ -20,7 +20,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "include/registry/converter_context.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_o.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc index c13b4d041e6250e345e7a1217055321a46ca05d7..e9a728502eb4e7dfb521c9f0e1774f18f2f20386 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc @@ -26,7 +26,7 @@ #include "ops_utils/op_utils.h" #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc index b147832a007e0f486998d70953f1602d0293e370..902af1107e90d6bcefa01e4a9ef27fbe20b84b3c 100644 --- 
a/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc @@ -24,7 +24,7 @@ #include "src/common/log_util.h" #include "mindspore/ops/op_def/op_name.h" #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc index 5faba4a5dc3dae23c665303b7b44ea23ea350111..4a363a586dea9034b7b190b85e053761dd0748bb 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/adapter/acl/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_r.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc index 1c8af3f0c4fc45fa112dc20800fa36c6a46a25ad..d52622b7cd3adb726a0c8b08f6673733cbc62b9a 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc @@ -28,7 +28,7 @@ #include "ops_utils/op_utils.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc index 1d58cc8e75a9a3f1c8137a80480f5dd2d56d84e0..9b252148e3f027657c7b1dff8d3c86fc7e64b5a1 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc @@ -20,7 +20,7 @@ #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "include/registry/converter_context.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_s.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc index bb0d07ca91a8c9ec8c04fa55e1dca915e8256851..2f02609cbdc512c0772ea4139354a157d3decf57 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc @@ -22,7 +22,7 @@ #include "tools/converter/adapter/acl/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/lite_exporter/fetch_content.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc index 
60a488003f6f28f8983d80952651f5b2b81b5f0d..99e30517ca1c10d234c2117dd274a44119ab80c2 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc @@ -23,7 +23,7 @@ #include "src/common/log_util.h" #include "infer/topk.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc index b99b878389f5c9c78fb548d828d45ae94b519d45..4dd96725f437fa3a381c23ae966f7154fe9edea4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc @@ -21,7 +21,7 @@ #include #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_t.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc index d75fd3cef08de1ebb5e4c2e4938f88152b5790a0..fd383b25ed8f46c4cf2176ca995d45fd613a9895 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/adapter/acl/common/utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc index f73d66ec17c5710742df63be36446a060d2bfa75..0aff2c4f5f685f76a425f337258496697ba39d2a 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc @@ -19,7 +19,7 @@ #include "tools/converter/adapter/acl/common/utils.h" #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc index f1752372f6952403f414d48e9592589aef456e48..197203c0b389ed3098afc6c42aafac1777d9066f 100644 --- a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc +++ b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc @@ -38,7 +38,7 @@ #include "infer/standard_normal.h" #include "infer/tuple_get_item.h" #include "cxx_api/model/acl/model_converter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "src/common/log_util.h" #include "src/common/file_utils.h" diff --git a/mindspore-lite/tools/converter/anf_transform.cc b/mindspore-lite/tools/converter/anf_transform.cc index 1501056291fd028f8765c587f1431f37f6222b1c..406986a0d80cc0fdef507ea19f8009aaca469bc6 100644 --- a/mindspore-lite/tools/converter/anf_transform.cc +++ 
b/mindspore-lite/tools/converter/anf_transform.cc @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_adapter.h" #include "tools/converter/optimizer_manager.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/anf_transform_for_ge.cc b/mindspore-lite/tools/converter/anf_transform_for_ge.cc index 9f6b9bf05276d0e52ac474b44036166b50998929..41d8cc61605190484670c9d3d720d26b6c40f8bb 100644 --- a/mindspore-lite/tools/converter/anf_transform_for_ge.cc +++ b/mindspore-lite/tools/converter/anf_transform_for_ge.cc @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_adapter.h" #include "tools/converter/optimizer_manager.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc b/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc index aab7d32d2ccbf4c1916985764ad8d03091e84a5e..1d47c900b906647b7f021c517ae65356908132e2 100644 --- a/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc +++ b/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc @@ -20,7 +20,7 @@ #include "tools/common/string_util.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/converter.cc b/mindspore-lite/tools/converter/converter.cc index 5d8291505171d2cab19528d69225ecbea3a129d2..196d7287991890f9a1b5af64422f553c8722718c 100644 --- a/mindspore-lite/tools/converter/converter.cc +++ b/mindspore-lite/tools/converter/converter.cc @@ -36,7 +36,7 @@ #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/converter/import/mindspore_importer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/coder.h" #include "src/common/prim_util.h" #include "src/common/version_manager.h" diff --git a/mindspore-lite/tools/converter/converter_funcgraph.cc b/mindspore-lite/tools/converter/converter_funcgraph.cc index a62d6d866a4332c751874164cf0f8954122aa767..1bfe5f9e27c951b81b0ebf45c26de118c22df24e 100644 --- a/mindspore-lite/tools/converter/converter_funcgraph.cc +++ b/mindspore-lite/tools/converter/converter_funcgraph.cc @@ -34,7 +34,7 @@ #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/converter/import/mindspore_importer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/coder.h" #include "src/common/prim_util.h" #include "src/common/version_manager.h" diff --git a/mindspore-lite/tools/converter/converter_packed_node.cc b/mindspore-lite/tools/converter/converter_packed_node.cc index beb91a4ccce937f70d6adc9ffbccdda9336239fb..45a6a617d7d52159526a947eb0e8a00b42551242 100644 --- a/mindspore-lite/tools/converter/converter_packed_node.cc +++ b/mindspore-lite/tools/converter/converter_packed_node.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32.h" #include "src/litert/kernel/cpu/nnacl/nnacl_kernel.h" -#include "nnacl/kernel/matmul_struct.h" +#include "nnacl_c/kernel/matmul_struct.h" namespace mindspore { namespace { diff --git a/mindspore-lite/tools/converter/export_model.cc b/mindspore-lite/tools/converter/export_model.cc index 
18b38208da0065cb175a8ac08fea5742d4ca2be0..47b23f35207168a62a4602680e89bf2800731617 100644 --- a/mindspore-lite/tools/converter/export_model.cc +++ b/mindspore-lite/tools/converter/export_model.cc @@ -34,7 +34,7 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/graph/control_flow_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/converter/import/mindir_adjust.cc b/mindspore-lite/tools/converter/import/mindir_adjust.cc index 8c727c8fcaf35318aaad1179daf2fe59c6deb9ee..fd18d283755bfeee9720e100d091d9a7d029c16b 100644 --- a/mindspore-lite/tools/converter/import/mindir_adjust.cc +++ b/mindspore-lite/tools/converter/import/mindir_adjust.cc @@ -26,7 +26,7 @@ #include "src/common/log_adapter.h" #include "src/common/quant_utils.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/fake_quant_param.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc b/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc index 502dd06323fdd7e8218ff2acfc76ed2823adb95a..7b38e03210849bcc536afe9286b98a4b35ac57eb 100644 --- a/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc +++ b/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc @@ -27,7 +27,7 @@ #include "tools/common/node_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" namespace { constexpr const int kSwitchTruePartialIndex = 2; diff --git a/mindspore-lite/tools/converter/import/mindspore_importer.cc b/mindspore-lite/tools/converter/import/mindspore_importer.cc index c496034ab016388ed6eccec138d92f2076d80ccb..233f923100ba3ed1fa579de48f379f0845721ecb 100644 --- a/mindspore-lite/tools/converter/import/mindspore_importer.cc +++ b/mindspore-lite/tools/converter/import/mindspore_importer.cc @@ -37,7 +37,7 @@ #include "tools/converter/parser/unify_format.h" #include "tools/converter/parser/lstm_adjust_pass.h" #include "tools/optimizer/graph/redundant_op_remove_pass.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/tools/converter/import/primitive_adjust.cc b/mindspore-lite/tools/converter/import/primitive_adjust.cc index 308d18fb5c81ba612ef0d8c7cb00ef9d446c8f23..c9760a749281476ffe5e52d89a47c7ffeae578da 100644 --- a/mindspore-lite/tools/converter/import/primitive_adjust.cc +++ b/mindspore-lite/tools/converter/import/primitive_adjust.cc @@ -67,7 +67,7 @@ #include "infer/random_standard_normal.h" #include "infer/fill.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git 
a/mindspore-lite/tools/converter/import/remove_public_primitive.cc b/mindspore-lite/tools/converter/import/remove_public_primitive.cc index 26cdfffa66b92f8c6401d7093f3ac7e751074187..c072441b2769b06f282d752f90cc82532e76c3db 100644 --- a/mindspore-lite/tools/converter/import/remove_public_primitive.cc +++ b/mindspore-lite/tools/converter/import/remove_public_primitive.cc @@ -20,7 +20,7 @@ #include #include "mindspore/ops/op_def/structure_ops.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc index a551196dbc8d2f91a1d23339e127244c46078c4a..bb04530dd604b6b9b9ae4d304ac1e1972646299e 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc +++ b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc @@ -29,7 +29,7 @@ #include "tools/common/meta_graph_utils.h" #include "include/errorcode.h" #include "schema/inner/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h index 1281053f24cbbba328cc5c89913557932ee325d8..fa1a39b5a4477c8e055a7ebe128d6060fc4ddbf9 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h +++ b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h @@ -24,7 +24,7 @@ #include #include "src/common/log_adapter.h" #include "schema/inner/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index 3913b38397d48b7de7f1cc9e450c93a42fc2a6cd..297391a82cd56c74effbbef7a0b80fdb34412f35 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -22,7 +22,7 @@ #include "tools/converter/optimizer.h" #include "tools/common/graph_util.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 8f650e0f4fb60060272535ead315af89ff235193..dde3e9227b66f0aabdeae2b87db24503ec7378a6 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -30,7 +30,7 @@ #include "tools/common/node_util.h" #include "src/common/string_utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::converter::kFmkTypeTf; namespace { diff --git a/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc b/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc index cee396e3547f800498ca2c386d5d0d4165929797..9883f4fffe80230dafc2af895a28eb1bf9f8eaa3 100644 --- a/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc +++ b/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc @@ -16,7 +16,7 
@@ #include "tools/converter/micro/coder/allocator/memory_manager.h" #include -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/opcoders/op_coder.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc index daafa97a73d3138c90384e33c9586e7b1e3a3bf9..ab2e8f4284664a2d70f3bb9fc36635eb6354ceeb 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc @@ -21,7 +21,7 @@ #include "coder/utils/coder_utils.h" #include "coder/log.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/c_api/model_c.h" #include "coder/generator/component/const_blocks/license.h" #include "tools/common/string_util.h" diff --git a/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc b/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc index 7a3415815b302b95042233ae59ad94e198437359..d0adbc7fa8492731330f733846c01a0efd1d83de 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc @@ -41,7 +41,7 @@ const char debug_utils_h[] = R"RAW( #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #define MICRO_INFO(content, args...) \ { printf("[INFO] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); } diff --git a/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc b/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc index 90bd64ca9aa7c61d1efcfa2496c7b0b859ba863b..efbb7964bd7d8786a2d2615376242e06dc23ce86 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc @@ -17,7 +17,7 @@ #include "coder/generator/component/train_component.h" #include #include "coder/utils/coder_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "coder/utils/type_cast.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/generator/generator.cc b/mindspore-lite/tools/converter/micro/coder/generator/generator.cc index e826b6811c49a9be5b3360ce4d644f415d2226b1..7eed5ce74b31b1d09e8bb013c96324e7059be403 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/generator.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/generator.cc @@ -571,8 +571,8 @@ int Generator::CodeRegKernelHFile() { MS_CHECK_TRUE(!cofs.bad(), "filed to open file"); MS_LOG(INFO) << "write " << reg_kernel_header; cofs << g_hwLicense; - cofs << "#include \"nnacl/tensor_c.h\"\n"; - cofs << "#include \"nnacl/custom_parameter.h\"\n\n"; + cofs << "#include \"nnacl_c/tensor_c.h\"\n"; + cofs << "#include \"nnacl_c/custom_parameter.h\"\n\n"; cofs << KernelRegistry::GetInstance()->GenKernelInterface(kCustomKernelName, kCustomKernelParam) << "\n"; return RET_OK; } diff --git a/mindspore-lite/tools/converter/micro/coder/log.h b/mindspore-lite/tools/converter/micro/coder/log.h index f22ea2b4b0db1c3fcd8e3703976321dae25eaafa..68b43c8e7f39ab68554f404929bc56348e59fb71 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/log.h +++ b/mindspore-lite/tools/converter/micro/coder/log.h @@ -19,7 +19,7 @@ #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #define MS_CHECK_PTR(ptr) \ do { \ diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc index 5fd77feb8191e695b1346a340a16de28db2f888e..9e887788e2eff2b050de3513dd627a732bad3d5e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc @@ -17,8 +17,8 @@ #include "coder/opcoders/base/conv2d_base_coder.h" #include #include -#include "nnacl/fp32/winograd_utils.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/int8/quantize.h" #include "coder/log.h" #include "src/litert/tensor_category.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h index b03d9e146cf47f5eb7a2ccf2918d7d4dfef1cbdc..dd5ef61c3e73234deca82c20c8055f13a4e29b19 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h @@ -23,7 +23,7 @@ #include #include "coder/opcoders/op_coder.h" #include "src/litert/kernel/cpu/base/layout_transform.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro { class Conv2DBaseCoder : public OperatorCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc index 738441b939c9056e534c57acc90af61ec34f8bbb..ef6131fabc186d122526db2d964bbcc10f048beb 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc @@ -15,8 +15,8 @@ */ #include "coder/opcoders/base/detection_post_process_base_coder.h" -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -125,8 +125,8 @@ int DetectionPostProcessBaseCoder::AllocateBuffer() { int DetectionPostProcessBaseCoder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/detection_post_process_parameter.h", - "nnacl/fp32/detection_post_process_fp32.h", + "nnacl_c/detection_post_process_parameter.h", + "nnacl_c/fp32/detection_post_process_fp32.h", "wrapper/base/detection_post_process_base_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h index fc2b216a9e2ee8dbe940a3c35c2e0923f00983fe..0099d8f991dd827935cf89fbaf706913fb9037f3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h @@ 
-22,7 +22,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/detection_post_process_parameter.h" +#include "nnacl_c/detection_post_process_parameter.h" #include "coder/opcoders/serializers/serializer.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc index be1f041dc3e0944f214bc31e79361cdf352807f2..6cebd9dccb1f5475268c9dba93ef6ce20f7c2459 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc @@ -123,7 +123,7 @@ int DTypeCastCoder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/cast_base.h", + "nnacl_c/base/cast_base.h", }, { "cast_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h index 03f9e54fb59fb61f4c1558d22c2fe36a7b474651..fcd1421e8fe1c9ec899f311727ef147a4d782342 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "coder/opcoders/serializers/serializer.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h index c11a703b4d6ab26fbedd29a55e50925fdfd86ea6..51be7340b98a77fd59f9634ecd2fa4524b604d13 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro { class FullConnectionBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc index 8d08edd542df1b4daed970ec2f04059f70deba3e..71c2d20558ff963b8a0e1ac6a77bcf13ba46bd7f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc @@ -47,7 +47,7 @@ int QuantDTypeCastCoder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/quant_dtype_cast_int8.h", + "nnacl_c/int8/quant_dtype_cast_int8.h", }, { "quant_dtype_cast_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h index 3e24703e4c8394934fe0d195e4d0e33634f2d83c..0aa013220148710b548543112a89d87001e7e1dd 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" namespace 
mindspore::lite::micro { class QuantDTypeCastCoder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h index af8d6d21d4393e98dcd8e4673813f3ae4a4d9d9c..3b6b6c2fb8b1db79344eedbd25d120e41fa58594 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::lite::micro { class ReduceBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h index bc6e5c5fcfe18304081449a2b0d0ba3867dad759..a6f78d7ba9f41cdf6918df0d9e72269f943be2a9 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" namespace mindspore::lite::micro { class ResizeBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h index 0ec4945158ee1ee9221151c18f20b6195a23f8c0..c295082fb1aff9adf73caec8ae113a41b72699c5 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h @@ -20,8 +20,8 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro { class SoftmaxBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc index ee8873424e2d92e602960176a4a395254c8f3580..c319ee337f03a9a73d8a19249bd020afbcdbf64a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc @@ -40,7 +40,7 @@ int StackFP32Coder::ReSize() { int StackFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/stack_base.h", + "nnacl_c/base/stack_base.h", }, { "stack_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h index 0807433266ad10c82ce2ed8db4edd1bfe8f2d554..33f98382d3b43c053c37e0d791a13cd0b8db9fe0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" namespace mindspore::lite::micro::nnacl { class StackFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc 
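The Collect(...) hunks in this stretch (stack, strided-slice, unstack, and the fp16 coders further down) all follow one shape: the first brace list names headers — which is why each entry needs the new nnacl_c/ prefix, since these strings end up as #include lines in the generated project — while the second list names the .c implementation files to bundle, which stay prefix-free. That reading is mine, not stated in the diff; a self-contained toy mirroring the call shape:

    #include <initializer_list>
    #include <iostream>
    #include <string>

    // Toy stand-in for the opcoder file collector; the real Collect lives in
    // coder/opcoders/file_collector.h. This mock only mirrors the call shape
    // seen in the hunks: headers carry the include prefix, sources do not.
    struct CoderContext {};

    void Collect(CoderContext *, std::initializer_list<std::string> headers,
                 std::initializer_list<std::string> sources) {
      for (const auto &h : headers) std::cout << "#include \"" << h << "\"\n";
      for (const auto &s : sources) std::cout << "collect source: " << s << "\n";
    }

    int main() {
      CoderContext ctx;
      // Mirrors the patched StackFP32Coder::DoCode call.
      Collect(&ctx, {"nnacl_c/base/stack_base.h"}, {"stack_base.c"});
      return 0;
    }
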
b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc index 44622279a1a980ccc4cf9e35f59a29f8ebddc864..07f63c934395b1340d61397946976089cc83babf 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc @@ -176,7 +176,7 @@ int StridedSliceBaseCoder::DoCode(CoderContext *ctx) { inner_size_ = GetInnerSize(input_tensor_->data_type(), inner_); Collect(ctx, { - "nnacl/fp32/strided_slice_fp32.h", + "nnacl_c/fp32/strided_slice_fp32.h", "wrapper/base/strided_slice_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h index 5e2f2d6f43414cbfabb001f6404d54fed486c569..87fbdf7bb3e3fdcf805ef2aa27f337e34e121b9f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_BASE_CODER_H_ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/fp32/strided_slice_fp32.h" +#include "nnacl_c/fp32/strided_slice_fp32.h" namespace mindspore::lite::micro { class StridedSliceBaseCoder final : public OperatorCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc index 11b27a4eef594813c4691f299b6a993a4700199f..b7e873584d8032721bce33e0d3a435f938a645ce 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc @@ -81,7 +81,7 @@ int StridedSliceDynamicBaseCoder::Prepare(CoderContext *context) { int StridedSliceDynamicBaseCoder::DoCode(CoderContext *ctx) { Collect(ctx, { - "nnacl/fp32/strided_slice_fp32.h", + "nnacl_c/fp32/strided_slice_fp32.h", }, { "strided_slice_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h index 1368c4e00ac10bbe4b9f31f16140448326945587..d7553b6e1b4a5857eddae7ee86a800d4089301be 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h @@ -19,8 +19,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h" -#include "nnacl/strided_slice_parameter.h" -#include "nnacl/kernel/strided_slice.h" +#include "nnacl_c/strided_slice_parameter.h" +#include "nnacl_c/kernel/strided_slice.h" namespace mindspore::lite::micro { class StridedSliceDynamicBaseCoder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc index 3a6912dfb3263f548e5eaf30b0f7dcc5c9537c4d..e06094e077212b1afb9bffc51ac656358108e2a2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc @@ -60,7 +60,7 @@ int UnstackBaseCoder::Prepare(CoderContext *context) { int UnstackBaseCoder::DoCode(CoderContext *ctx) { Collect(ctx, { - "nnacl/base/unstack_base.h", + "nnacl_c/base/unstack_base.h", }, { "unstack_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h index e095be3cfec02716c150fc273d43dab8173688ca..84c2e938336b33bc15faeac6626972099af10c74 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h @@ -17,8 +17,8 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_UNSTACK_BASE_CODER_H_ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/base/unstack_base.h" -#include "nnacl/op_base.h" +#include "nnacl_c/base/unstack_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::micro { class UnstackBaseCoder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc index 31fbde7ab0e2c1aa353e5428cf04dba4fd19cea9..2e92d877c002d930c534643f195a3066b551333c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc @@ -23,8 +23,8 @@ #include "coder/opcoders/serializers/serializer.h" #include "coder/utils/common.h" #include "mindspore/ops/op_def/array_ops.h" -#include "nnacl/arithmetic_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/int8/quantize.h" using mindspore::schema::PrimitiveType_AddFusion; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc index 431e78423ebe0c91c1c8ca812232c8aeb7872128..b7f0a556a0e0154d80e1b5d370256bc7e34b4f84 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro::cmsis { int Conv2DBaseCoder::SetQuantArgs() { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h index 7fb4937314890ecc3e648bd007c3ba48191911d7..c9bb04ee501527327352cfee53a1fd5f19719fa9 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/base/conv2d_base_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::cmsis { class Conv2DBaseCoder : public micro::Conv2DBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h index 
c0409d7cc710d4c40dfee152c3d829d9350422ee..1ead129abbe7ede91ec4ccadd6fb52a3d2a44410 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::cmsis { class Conv2DInt8Coder final : public Conv2DBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h index 8e4bdec4fa29144b9c44f0ff7e9c0265ada702b5..aa79e03bd409752daa5b45caed030e996fe5129a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h @@ -21,7 +21,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/base/full_connection_base_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro::cmsis { class FullConnectionInt8Coder final : public FullConnectionBaseCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc index 5f0fe01b956531052277ce418c58d5a0941d301d..4b6bb69ea2ed1a4b2b4025751910633ccae472cb 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/cmsis-nn/int8/mul_int8_coder.h" #include #include "coder/opcoders/serializers/serializer.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/file_collector.h" using mindspore::schema::PrimitiveType_MulFusion; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h index 058009680bb0a4695c0f41f352bf2b8c10c1fd6f..daa157ed291a06e53fda61f96d869c8a5aabbf45 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h @@ -21,7 +21,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/pooling_int8.h" +#include "nnacl_c/int8/pooling_int8.h" namespace mindspore::lite::micro::cmsis { class PoolingInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc index 2a40de01d3329452a39fc4fbe1488266cb2afe30..f749d362823e23ede62fccc81a5d7bd0e69fa106 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc @@ -23,7 +23,7 @@ #include "tools/converter/micro/coder/opcoders/op_coder_register.h" #include "tools/converter/micro/coder/opcoders/kernel_registry.h" #include "src/common/prim_util.h" -#include "nnacl/custom_parameter.h" +#include "nnacl_c/custom_parameter.h" using 
mindspore::schema::PrimitiveType_Custom;
@@ -151,7 +151,7 @@ void CustomCoder::FreeTensors(Serializer *code, std::string array_name, size_t t
 }
 
 int CustomCoder::DoCode(CoderContext *const context) {
-  Collect(context, {"nnacl/custom_parameter.h", "nnacl/tensor_c.h", "src/registered_kernel.h"}, {});
+  Collect(context, {"nnacl_c/custom_parameter.h", "nnacl_c/tensor_c.h", "src/registered_kernel.h"}, {});
   Serializer code;
   MS_CHECK_RET_CODE(TransformTensors(&code, "inputs", input_tensors_), "Transform input tensors error!");
   MS_CHECK_RET_CODE(TransformTensors(&code, "outputs", output_tensors_), "Transform output tensors error!");
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc
index b16b5e6d9d6848fe12b7b124f158f0e36036411f..d01d4c11afc8a90a0e3a4bd3c982a6b4c532ea73 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc
@@ -33,7 +33,7 @@ int ActivationDynamicFP16Coder::Prepare(CoderContext *const context) {
 int ActivationDynamicFP16Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/fp16/activation_fp16.h",
+            "nnacl_c/fp16/activation_fp16.h",
           },
           {
             "activation_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc
index 0fdf0a7fe1c9e80b805153a6a0083d6c92762755..cbc12059ce0053c8cefca8c60e9642ebfa863f32 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc
@@ -35,7 +35,7 @@ int ActivationFP16Coder::DoCode(CoderContext *const context) {
   int count = input_tensor_->ElementsNum();
   Collect(context,
           {
-            "nnacl/fp16/activation_fp16.h",
+            "nnacl_c/fp16/activation_fp16.h",
           },
           {
             "activation_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc
index 3228e7e277e9252004425c2201472869f4453cb5..001a9af1093004bde3f29584929532acc5a08f39 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc
@@ -98,8 +98,8 @@ int ArithmeticDynamicFP16Coder::DoCode(CoderContext *const context) {
   NNaclFp32Serializer code;
   Collect(context,
           {
-            "nnacl/fp16/arithmetic_fp16.h",
-            "nnacl/base/broadcast_to.h",
+            "nnacl_c/fp16/arithmetic_fp16.h",
+            "nnacl_c/base/broadcast_to.h",
           },
           {
             "arithmetic_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h
index ca958d7385d33367cf65e711b843b860262cb9f5..6451cad21e252992cd2bb24eb84abb83382ebfd8 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h
@@ -20,11 +20,11 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/base/cast_base.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/base/cast_base.h"
+#include "nnacl_c/arithmetic_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h"
-#include "nnacl/broadcast_to_parameter.h"
+#include "nnacl_c/broadcast_to_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 using mindspore::schema::PrimitiveType_AddFusion;
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc
index c31f7172b9a0b6b104adfdd0e0fe095fba5eaa4e..adaa95eb72b445791d7bef22666a4199cb89950e 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
-#include "nnacl/broadcast_to_parameter.h"
+#include "nnacl_c/broadcast_to_parameter.h"
 #include "base/float16.h"
 
 namespace mindspore::lite::micro::nnacl {
@@ -105,8 +105,8 @@ int ArithmeticFP16Coder::DoCode(CoderContext *const context) {
   NNaclFp32Serializer code;
   Collect(context,
           {
-            "nnacl/fp16/arithmetic_fp16.h",
-            "nnacl/base/broadcast_to.h",
+            "nnacl_c/fp16/arithmetic_fp16.h",
+            "nnacl_c/base/broadcast_to.h",
           },
           {
             "arithmetic_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc
index 4d7fe34a3bfae91cfb4c206cdfd60f163c86c435..02aaa731dcf4d9f73eb5cd7ae3e29fdd154fb128 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc
@@ -59,7 +59,7 @@ int ArithmeticSelfFP16Coder::DoCode(CoderContext *const context) {
 
   Collect(context,
           {
-            "nnacl/fp16/arithmetic_self_fp16.h",
+            "nnacl_c/fp16/arithmetic_self_fp16.h",
           },
           {
             "arithmetic_self_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h
index e95229f7434f6027968ba64548d63451c1b4b628..aa2db70513ece971a340b0acc3f025e5803e71b2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/arithmetic_self_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 using mindspore::schema::PrimitiveType_Abs;
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc
index 063ef0d8b65414c2264b58e57d9b228f817e5fb4..fa40afcbb8af1578afcb1c445e0b8eb7b7c2431d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc
@@ -40,7 +40,7 @@ int ConcatDynamicFP16Coder::Prepare(CoderContext *const context) {
 int ConcatDynamicFP16Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/base/concat_base.h",
+            "nnacl_c/base/concat_base.h",
           },
           {
             "concat_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h
index 6408403b1ae708917f272a3e2241ab90c88e37a8..678259462534476aa259177374ac0e242ef6d79b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 class ConcatDynamicFP16Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc
index fd9699633094c1dfcd87df3f1ec432637ae28787..52b11e206981aa6f0750e9494427d83967ab9d81 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc
@@ -42,7 +42,7 @@ int ConcatFP16Coder::ReSize() {
 int ConcatFP16Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/base/concat_base.h",
+            "nnacl_c/base/concat_base.h",
           },
           {
             "concat_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h
index 6428ac6fe49edc11aa61171b3172c43abb78dbfd..e3cafe327f7b7ae3aed954b9ca7fbed88520f480 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/nnacl/fp32/concat_fp32_coder.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 class ConcatFP16Coder final : public ConcatFP32Coder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc
index 2c1e01afef7a708dea768599ec4bd20be107d68b..7ea7893d408b85e4849d948d9ff9bd297b18a247 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc
@@ -18,9 +18,9 @@
 #include "src/common/version_manager.h"
 #include "src/common/tensor_util.h"
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/winograd_utils.h"
-#include "nnacl/base/conv_common_base.h"
-#include "nnacl/infer/conv2d_infer.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+#include "nnacl_c/base/conv_common_base.h"
+#include "nnacl_c/infer/conv2d_infer.h"
 #include "coder/shape_info_container.h"
 #include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h"
 #include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h
index 78dd3ebf7d94598df790bba0ea55aeb1302c16b9..6852962fb82820eaf9b8713351248dfa4e2651ea 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h
@@ -20,7 +20,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 class ConvDelegateDynamicFP16Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc
index 117ba90fb1b47b1a41d430a14a7895bb432bf0e0..9d4cb9aac7418be5ed3e0fcc93ece76c0b51913f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc
@@ -18,9 +18,9 @@
 #include "src/common/version_manager.h"
 #include "src/common/tensor_util.h"
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/winograd_utils.h"
-#include "nnacl/base/conv_common_base.h"
-#include "nnacl/infer/conv2d_infer.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+#include "nnacl_c/base/conv_common_base.h"
+#include "nnacl_c/infer/conv2d_infer.h"
 #include "coder/opcoders/nnacl/fp16/convolution_fp16_coder.h"
 #include "coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.h"
 #include "coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h
index 094b01b65a92805bc39df46f5f53369d5cc32b09..923bbd2971c5a08956c09d034784b7816e4679fd 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h
@@ -20,7 +20,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
 #include "coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 class ConvDelegateFP16Coder : public ConvDelegateCoder {
  public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc
index 690d0e1bde6d309d76a82f4a64d88e0f9c23448c..05654d8857809d06949924a50cb80681c025bb77 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc
@@ -90,8 +90,8 @@ void ConvolutionDepthwise3x3FP16Coder::CollectFilesForFunc(CoderContext *const c
   }
   Collect(context,
           {
-            "nnacl/fp16/conv_depthwise_fp16.h",
-            "nnacl/fp16/pack_fp16.h",
+            "nnacl_c/fp16/conv_depthwise_fp16.h",
+            "nnacl_c/fp16/pack_fp16.h",
           },
           {
             "conv_depthwise_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h
index 50c07f4e1aaf66f467a493469ca14b83733bf5a0..fee118732da53d91b5f9b1ce6db117d026276e23 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/base/conv2d_base_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc
index 0140e2a4e40c050d99cb147e5a0666909b7d9e7f..72ccac96d1e0a8c11df54da6b7dd752320644a06 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc
@@ -65,9 +65,9 @@ void ConvolutionDepthwiseFP16Coder::CollectFilesForFunc(CoderContext *const cont
   }
   Collect(context,
           {
-            "nnacl/fp16/conv_depthwise_fp16.h",
-            "nnacl/fp16/pack_fp16.h",
-            "nnacl/fp16/activation_fp16.h",
+            "nnacl_c/fp16/conv_depthwise_fp16.h",
+            "nnacl_c/fp16/pack_fp16.h",
+            "nnacl_c/fp16/activation_fp16.h",
           },
           {
             "conv_depthwise_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc
index 10acc3768fbdcebbb4897089fee8b00878b915c2..6a0aae3387dea5789491c843ed8a5e04e021267c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc
@@ -95,9 +95,9 @@ void ConvolutionDepthwiseSWFP16Coder::CollectFilesForFunc(CoderContext *const co
   }
   Collect(context,
           {
-            "nnacl/fp32/conv_depthwise_fp32.h",
-            "nnacl/fp16/conv_depthwise_fp16.h",
-            "nnacl/fp16/pack_fp16.h",
+            "nnacl_c/fp32/conv_depthwise_fp32.h",
+            "nnacl_c/fp16/conv_depthwise_fp16.h",
+            "nnacl_c/fp16/pack_fp16.h",
           },
           {
             "conv_depthwise_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc
index 06b3e4ef6eec8572562d7a95500648dc82c651ef..2c9544fc4280aecf28ceec9614b3a1bf0644a26a 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h"
 #include
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/utils/coder_utils.h"
@@ -231,11 +231,11 @@ void Convolution1x1DynamicFP16Coder::CollectFilesForFunc(CoderContext *const con
   }
   Collect(context,
           {
-            "nnacl/fp16/matmul_fp16.h",
-            "nnacl/conv_parameter.h",
-            "nnacl/op_base.h",
-            "nnacl/fp16/conv_fp16.h",
-            "nnacl/base/conv1x1_base.h",
+            "nnacl_c/fp16/matmul_fp16.h",
+            "nnacl_c/conv_parameter.h",
+            "nnacl_c/op_base.h",
+            "nnacl_c/fp16/conv_fp16.h",
+            "nnacl_c/base/conv1x1_base.h",
           },
           {
             "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h
index abc34ad2ee1ccec524ee7bf08cdb1de06fc8000a..a0558c9a81cbef03bbe08314ef3a23eaf2492b92 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h
@@ -19,8 +19,8 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "coder/opcoders/op_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc
index 6c1aedea8f982acc70157e1bc7a81f28be095c63..b997b84075967c13d1e28f93b0ff313eaa26a4d7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc
@@ -17,7 +17,7 @@
 #include "coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h"
 #include
 #include
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/parallel.h"
@@ -148,11 +148,11 @@ void Convolution1x1FP16Coder::CollectFilesForFunc(CoderContext *const context) {
   }
   Collect(context,
           {
-            "nnacl/fp16/matmul_fp16.h",
-            "nnacl/conv_parameter.h",
-            "nnacl/op_base.h",
-            "nnacl/fp16/conv_fp16.h",
-            "nnacl/base/conv1x1_base.h",
+            "nnacl_c/fp16/matmul_fp16.h",
+            "nnacl_c/conv_parameter.h",
+            "nnacl_c/op_base.h",
+            "nnacl_c/fp16/conv_fp16.h",
+            "nnacl_c/base/conv1x1_base.h",
             "wrapper/base/micro_parameter.h",
           },
           {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h
index f1e88619a5909a699a03266c9fd16625da677b12..6c5afb1f6e72113594fe68e4ee6dd75dc5119ed1 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/base/conv2d_base_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "base/float16.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc
index 19d1ab92cc042b0b7157e3a742a121e3507575e1..4f2782ea1bab8cf8cb572907a5c991229839c607 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc
@@ -17,7 +17,7 @@
 #include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h"
 #include
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -131,10 +131,10 @@ void ConvolutionDynamicFP16Coder::CollectFilesForFunc(CoderContext *const contex
           });
   Collect(context,
           {
-            "nnacl/fp16/matmul_fp16.h",
-            "nnacl/conv_parameter.h",
-            "nnacl/op_base.h",
-            "nnacl/fp16/conv_fp16.h",
+            "nnacl_c/fp16/matmul_fp16.h",
+            "nnacl_c/conv_parameter.h",
+            "nnacl_c/op_base.h",
+            "nnacl_c/fp16/conv_fp16.h",
           },
           {
             "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h
index 1ba4753058e248adec180bab9a87f75c13d7662b..617e904b6cf9717a50ce99ecaf367ac0ce9bcfac 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/op_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc
index 43f0e00e64c7ce8c6614f5762f69358a23c25632..874dda1c817dd61b4ab83c08c9dc045b849cb125 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -101,10 +101,10 @@ void ConvolutionFP16Coder::CollectFilesForFunc(CoderContext *const context) {
           });
   Collect(context,
           {
-            "nnacl/fp16/matmul_fp16.h",
-            "nnacl/conv_parameter.h",
-            "nnacl/op_base.h",
-            "nnacl/fp16/conv_fp16.h",
+            "nnacl_c/fp16/matmul_fp16.h",
+            "nnacl_c/conv_parameter.h",
+            "nnacl_c/op_base.h",
+            "nnacl_c/fp16/conv_fp16.h",
           },
           {
             "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h
index 2c92876f4f9b4a9817c17d5d81200271db8b6013..428478308e0f2e6de5fad078dd921f7283ac22af 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/nnacl/fp32/convolution_fp32_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc
index b03e484e902a625da5ebfd46b9002193bfe4dc47..68d4b74bfca55f2cf615fd46e969c899209c8372 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc
@@ -15,7 +15,7 @@
  */
 #include "coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h"
 #include
-#include "nnacl/base/minimal_filtering_generator.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
@@ -205,10 +205,10 @@ std::string ConvolutionWinogradFP16Coder::GetOutputTransFunc(int input_unit, int
 void ConvolutionWinogradFP16Coder::CollectFilesForFunc(CoderContext *const context) {
   Collect(context,
-          {"nnacl/fp16/conv_fp16.h", "nnacl/fp16/winograd_utils_fp16.h",
-           "nnacl/fp16/winograd_transform_fp16.h"
-           "nnacl/base/minimal_filtering_generator.h"
-           "nnacl/base/conv_common_base.h"},
+          {"nnacl_c/fp16/conv_fp16.h", "nnacl_c/fp16/winograd_utils_fp16.h",
+           "nnacl_c/fp16/winograd_transform_fp16.h",
+           "nnacl_c/base/minimal_filtering_generator.h",
+           "nnacl_c/base/conv_common_base.h"},
           {
             "conv_fp16.c",
             "winograd_utils_fp16.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h
index 824dbc5dba042bf172df99d83c6d68a5f5a8fead..8f46a73aed94fd83e18e5a5792aae24f698e1219 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h
@@ -21,7 +21,7 @@
 #include
 #include
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 typedef struct TransFuncFp16Str {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
index 5470b56ab03d566e1e3e5ae751cff0be2aa68cfd..19bdfad19c96286711aa33e3d80be3fe2658da5c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
@@ -21,7 +21,7 @@ using mindspore::schema::PrimitiveType_Custom;
 
 namespace mindspore::lite::micro::nnacl {
 void CustomGruFP16Coder::InitNnaclFile(CoderContext *const context) {
-  Collect(context, {"nnacl/fp16/custom_gru_fp16.h"},
+  Collect(context, {"nnacl_c/fp16/custom_gru_fp16.h"},
           {"custom_gru_fp16.c", "pack_fp16.c", "matmul_fp16.c", "arithmetic_fp16.c", "activation_fp16.c"});
 }
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
index ef38f1f0d52b27b340976c5da1814aade51ddae0..1d16fce69eb0e84d73ff0046a1465262a2e40e6b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h"
-#include "nnacl/custom_gru_parameter.h"
+#include "nnacl_c/custom_gru_parameter.h"
 
 namespace mindspore::lite::micro::nnacl {
 class CustomGruFP16Coder : public CustomGruFP32Coder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc
index d6be5f03063cd26d952f986b7061d0e183ede77d..dfbdf43ef5fabfd10cb05f1ce6972f19e44fd613 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc
@@ -99,16 +99,16 @@ void DeConvolutionFP16Coder::CollectFilesForFunc(CoderContext *const context) {
   }
   Collect(context,
           {
-            "nnacl/fp16/deconv_fp16.h",
-            "nnacl/fp16/pack_fp16.h",
-            "nnacl/fp16/matmul_fp16.h",
-            "nnacl/fp16/common_func_fp16.h",
-            "nnacl/base/minimal_filtering_generator.h",
-            "nnacl/conv_parameter.h",
-            "nnacl/common_func.h",
-            "nnacl/matmul_parameter.h",
+            "nnacl_c/fp16/deconv_fp16.h",
+            "nnacl_c/fp16/pack_fp16.h",
+            "nnacl_c/fp16/matmul_fp16.h",
+            "nnacl_c/fp16/common_func_fp16.h",
+            "nnacl_c/base/minimal_filtering_generator.h",
+            "nnacl_c/conv_parameter.h",
+            "nnacl_c/common_func.h",
+            "nnacl_c/matmul_parameter.h",
             "wrapper/base/micro_parameter.h",
-            "nnacl/op_base.h",
+            "nnacl_c/op_base.h",
           },
           {
             "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h
index 664fc24b5b3c9dd729f0857966417c8c7f778154..683c2b569917c1845bf01750cf256560cceca081 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc
index 638a3e0bf62838dd07d5f335708a3ed9b712d629..9ac0bb0033ddb8f2b1ecd382d9eca58fbe30b4e5 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc
@@ -33,7 +33,7 @@ int LayerNormFP16Coder::Prepare(CoderContext *const context) {
 int LayerNormFP16Coder::DoCode(CoderContext *const context) {
   NNaclFp32Serializer code;
   code.CodeStruct("layer_norm_compute_parm", compute_);
-  Collect(context, {"nnacl/fp16/layer_norm_fp16.h"}, {"layer_norm_fp16.c"});
+  Collect(context, {"nnacl_c/fp16/layer_norm_fp16.h"}, {"layer_norm_fp16.c"});
   if (output_tensors_.size() == C3NUM) {
     code.CodeFunction("LayerNormFp16", input_tensor_, input_tensors_.at(SECOND_INPUT), input_tensors_.at(THIRD_INPUT),
                       output_tensor_, output_tensors_.at(SECOND_INPUT), output_tensors_.at(THIRD_INPUT),
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h
index df025e3cd97635c36363249faabe0710d52b9f29..7be90eb62369b123006458e33ffa79def51ba741 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h" -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" namespace mindspore::lite::micro::nnacl { class LayerNormFP16Coder final : public LayerNormFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc index f2a1eaa364077a07e62fadc5d7ca4ccf84170906..213144f26749376fa52c2387780441fa7a8c2739 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc @@ -198,8 +198,8 @@ int LstmFP16Coder::Prepare(CoderContext *const context) { int LstmFP16Coder::DoCode(CoderContext *context) { Collect(context, { - "nnacl/lstm_parameter.h", - "nnacl/fp16/lstm_fp16.h", + "nnacl_c/lstm_parameter.h", + "nnacl_c/fp16/lstm_fp16.h", }, { "lstm_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h index fbaa3bd0690ff417b0163f6ad8b44fc384d82364..8ab283b98c32c670e96286c07fcd99cb9940e71f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp32/lstm_fp32_coder.h" -#include "nnacl/lstm_parameter.h" +#include "nnacl_c/lstm_parameter.h" namespace mindspore::lite::micro::nnacl { class LstmFP16Coder final : public LstmFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc index 66b6db4b22913edc20fad946a452235dc7124b8e..86dda6589e4885530a9c9885e6249dbbe5f6b851 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc @@ -46,7 +46,7 @@ int LstmMindirDynamicFP16Coder::Prepare(CoderContext *const context) { } int LstmMindirDynamicFP16Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/lstm_parameter.h", "nnacl/fp16/lstm_fp16.h"}, + Collect(context, {"nnacl_c/lstm_parameter.h", "nnacl_c/fp16/lstm_fp16.h"}, {"lstm_fp16.c", "activation_fp16.c", "arithmetic_fp16.c", "matmul_fp16.c", "pack_fp16.c"}, {"MatmulBaseFp16Neon.S"}); auto ret = InitInputWeightBias(context); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h index a348b917d28c9c365a0e43509744fcc372525a1a..1526547c75dc9941640f65b877d2cebf2f165919 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/lstm_parameter.h" +#include "nnacl_c/lstm_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h" #include "coder/opcoders/op_coder.h" diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc index f1b6548850361797928de12f400bd8aa62c56f22..a82db8fb3d99f931e0da1627c2b6659ef2d43fd3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc @@ -208,8 +208,8 @@ int MatMulDynamicFP16BaseCoder::ComputeMatrixAWorkspace() { int MatMulDynamicFP16BaseCoder::CollectFilesForTarget(CoderContext *const context) { Collect(context, { - "nnacl/fp16/pack_fp16.h", - "nnacl/fp16/matmul_fp16.h", + "nnacl_c/fp16/pack_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", }, { "pack_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h index 250fb96ba85908e9bd0446f7d99c42b1cb852d11..29de2e3c68297185689f8391d5c275166b4bcd16 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h @@ -21,7 +21,7 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" #include "tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h" #include "base/float16.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h index 768143b2c6cfde47d3b43f043451ab9f69d71e02..6f1a38c0e7412fa7aa0e541d1a9ab1c467441c5b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulDynamicFP16Coder final : public MatMulDynamicFP16BaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc index 1bf4e8ca9b0ddd1a8e4a767b293b79b9f336fe4d..e188b928c6217e780c0eca223847481e04214db1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc @@ -227,8 +227,8 @@ int MatMulFP16BaseCoder::Prepare(CoderContext *const context) { int MatMulFP16BaseCoder::CollectFilesForTarget(CoderContext *const context) { Collect(context, { - "nnacl/fp16/pack_fp16.h", - "nnacl/fp16/matmul_fp16.h", + "nnacl_c/fp16/pack_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", }, { "pack_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h index 
b56a8c1256a1179c65420d31c74ae34b0b4c1215..47398fbe5996dc3228ccc38a3d0e8b03748ce1c2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h @@ -21,7 +21,7 @@ #include #include "coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulFP16BaseCoder : public MatMulFP32BaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h index c5ea36cdcdcb5532b5f0100f6c490f93967fc1fb..dd381cedb002ac5dd8e6219d98c96bd175a123b2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulFP16Coder final : public MatMulFP16BaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc index 8fd99da3fe5477d4bb21ab801aa6fa62ce4275d6..49abf35b62379817d07e76bd87f95900491acae1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc @@ -65,7 +65,7 @@ int PoolingDynamicFP16Coder::Prepare(CoderContext *const context) { int PoolingDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/pooling_fp16.h", + "nnacl_c/fp16/pooling_fp16.h", }, { "pooling_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h index d36f935635c9e25adaeedf91ae12e300136c94dd..79b67ccc5928651b875252832684c239858ef117 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h @@ -20,8 +20,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/kernel/pooling.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/kernel/pooling.h" namespace mindspore::lite::micro::nnacl { class PoolingDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc index b043a8815017a35a26f8ec27486dfcdd94e69364..513f1ae32c939595163504496605f53ff4b3e418 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc @@ -56,8 +56,8 @@ int PoolingFP16Coder::DoCode(CoderContext *const context) { float maxf = FLT16_MAX; 
Collect(context, { - "nnacl/fp16/pooling_fp16.h", - "nnacl/kernel/pooling.h", + "nnacl_c/fp16/pooling_fp16.h", + "nnacl_c/kernel/pooling.h", }, { "pooling_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc index 5d2bf54c8e8f1aa655ada7427801abcd4a1d25af..31b5f501c5108f06b9cb07c0e577bce693d70c59 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc @@ -34,7 +34,7 @@ int ReduceFP16Coder::Prepare(CoderContext *const context) { int ReduceFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/reduce_fp16.h", + "nnacl_c/fp16/reduce_fp16.h", }, { "reduce_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc index 18c40b8b18e7c91be02d9c47ed876af0bf5b35e4..c19ef95f1f8f161e6b023d7791e39fabdf374ce4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc @@ -22,7 +22,7 @@ #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/utils/common.h" -#include "nnacl/fp32/resize_fp32.h" +#include "nnacl_c/fp32/resize_fp32.h" #include "base/float16.h" using mindspore::schema::PrimitiveType_Resize; @@ -33,8 +33,8 @@ int ResizeFP16Coder::DataTypeLen() { return sizeof(uint16_t); } int ResizeFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/resize_fp16.h", - "nnacl/fp32/resize_fp32.h", + "nnacl_c/fp16/resize_fp16.h", + "nnacl_c/fp32/resize_fp32.h", }, { "resize_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h index 769bb62c764d8f33a911cd21997e009285cfef82..59ae67147527286f9bb9cd0e2772a98796b5046d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h @@ -23,7 +23,7 @@ #include #include "include/errorcode.h" #include "src/executor/kernel_exec.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore::lite::micro::nnacl { class ResizeFP16Coder : public ResizeFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc index 0d6790e5dadcf6817c06efcd84ebd05936c7864c..05c905f7f648554c7b6b9c9ea2f92fff8c6ecf66 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc @@ -50,8 +50,8 @@ int ScaleDynamicFP16Coder::DoCode(CoderContext *const context) { // init struct ScaleParameters Collect(context, { - "nnacl/kernel/scale.h", - "nnacl/fp16/scale_fp16.h", + "nnacl_c/kernel/scale.h", + "nnacl_c/fp16/scale_fp16.h", }, { "scale_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h index e64286a893110a7b6e2eef4bfc6226dc37ab812a..723bf08ed38e366d6b708c9d17774771e5cfb6c1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h @@ -20,8 +20,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h" -#include "nnacl/kernel/scale.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" +#include "nnacl_c/scale_parameter.h" namespace mindspore::lite::micro::nnacl { class ScaleDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc index d03aeabd7507343e1259e2efe8f3bd169bb5aba8..ee6c0b7c241ad489c7a943df675217f2b1a9f461 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc @@ -42,9 +42,9 @@ int ScaleFP16Coder::DoCode(CoderContext *const context) { // init struct ScaleParameters Collect(context, { - "nnacl/scale_parameter.h", - "nnacl/kernel/scale.h", - "nnacl/fp16/scale_fp16.h", + "nnacl_c/scale_parameter.h", + "nnacl_c/kernel/scale.h", + "nnacl_c/fp16/scale_fp16.h", }, { "scale_fp32_wrapper.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h index 5032cf5172dafb8355ae6afcaba3904fb7122f56..b4da8a24c47290554495a6b587cafc7ae8a4c38f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h @@ -20,8 +20,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/fp32/scale_fp32_coder.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/kernel/scale.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" namespace mindspore::lite::micro::nnacl { class ScaleFP16Coder final : public ScaleFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc index a4ffaa0d3810e1378364a6e1aa17c2d34d96c794..e486120634d8b36d11f728be7d92e522bae14a5f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc @@ -44,7 +44,7 @@ int SliceDynamicFP16Coder::Prepare(CoderContext *const context) { int SliceDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/slice_base.h", + "nnacl_c/base/slice_base.h", }, { "slice_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h index 35b0bde08505f5f72e712afb63172545e3220dac..6defebcf0953cd21c0aacdd7482d76e7f209b8d9 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h @@ -21,8 +21,8 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" -#include "nnacl/kernel/slice.h" -#include "nnacl/op_base.h" +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::micro::nnacl { class SliceDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc index d75c9f676f08efd1fbfcd126f680c21303a4699e..36517122a2f7e7d453a0d3ddcf9c9f63389175e0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc @@ -24,7 +24,7 @@ namespace mindspore::lite::micro::nnacl { int SliceFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/slice_base.h", + "nnacl_c/base/slice_base.h", }, { "slice_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h index d6f503ec31d1f01b3c276a0583a5f74684ca08c1..ec5fb0a0d87831889a3f0a1ec73d559380881546 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h @@ -20,7 +20,7 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" #include "tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h" -#include "nnacl/kernel/slice.h" +#include "nnacl_c/kernel/slice.h" namespace mindspore::lite::micro::nnacl { class SliceFP16Coder final : public SliceFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc index 9c08e9f3fd024d1d9d023f2ad17d10fc4f5fc284..270469687787de0e0a061c6cab3189bd9e467823 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc @@ -44,8 +44,8 @@ int SoftmaxDynamicFP16Coder::Prepare(CoderContext *const context) { int SoftmaxDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/softmax_fp16.h", - "nnacl/fp16/log_softmax_fp16.h", + "nnacl_c/fp16/softmax_fp16.h", + "nnacl_c/fp16/log_softmax_fp16.h", }, { "softmax_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h index 1063969baa9da1c553456613c50516f04917ab0a..4041db0faff0abe688bf4ac054b90ae9ba26f310 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h @@ -21,8 +21,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/kernel/softmax.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/kernel/softmax.h" namespace 
mindspore::lite::micro::nnacl { class SoftmaxDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc index ceea05ede4d5af4a4ae8be11f66b711cd2e5e875..4bc8e6908b0f711e0864a01b5e2f094cecb409de 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc @@ -40,8 +40,8 @@ int SoftMaxFP16Coder::Prepare(CoderContext *const context) { int SoftMaxFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/softmax_fp16.h", - "nnacl/fp16/log_softmax_fp16.h", + "nnacl_c/fp16/softmax_fp16.h", + "nnacl_c/fp16/log_softmax_fp16.h", }, { "softmax_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc index 59c8d8b81126d8dd22ddedfeeb121c2f8c955f6d..7ce12f96a0ca807a918f2c3f4183cb2ac653c0d1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc @@ -40,9 +40,9 @@ int TransposeDynamicFp16Coder::Prepare(CoderContext *const context) { int TransposeDynamicFp16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/transpose_parameter.h", - "nnacl/errorcode.h", - "nnacl/fp16/transpose_fp16.h", + "nnacl_c/transpose_parameter.h", + "nnacl_c/errorcode.h", + "nnacl_c/fp16/transpose_fp16.h", }, { "transpose_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h index b31f1022a467e0316fad6a037c909215bc0ee142..62594117e01e65182724be7cd3700aa7b89a8473 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc index 3d826cae5102d8fb5a4c2e09903f81dff9e88c69..1ada88d62c0b63fe639be7afacf3801cc4728de3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc @@ -81,9 +81,9 @@ int TransposeFp16Coder::ResetStatus() { int TransposeFp16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/transpose_parameter.h", - "nnacl/errorcode.h", - "nnacl/fp16/transpose_fp16.h", + "nnacl_c/transpose_parameter.h", + "nnacl_c/errorcode.h", + "nnacl_c/fp16/transpose_fp16.h", }, { "transpose_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h index ce99f558bd2cdf6fcefb5c2933ec2153e1309a66..91d68a7a2319bc1fc3d30cdd83d4dbe9faf1e820 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/transpose_fp32_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" namespace mindspore::lite::micro::nnacl { class TransposeFp16Coder final : public TransposeFp32Coder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc index edc442e9483247879e9b58efb2a979d7cf112935..bae2a83938d89bd4d6e17eca6bc1122fa5ef1025 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "coder/opcoders/nnacl/fp32/activation_fp32_coder.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" @@ -29,7 +29,7 @@ int ActivationFP32Coder::DoCode(CoderContext *const context) { Collect(context, { "wrapper/fp32/activation_fp32_wrapper.h", - "nnacl/fp32/activation_fp32.h", + "nnacl_c/fp32/activation_fp32.h", }, { "activation_fp32_wrapper.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc index d2394514ab0c595f4ccdfb2967563a51f7b56746..04465a07f549de6c5f68d6228d7311dc25179037 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc @@ -31,7 +31,7 @@ int AddNFP32Coder::DoCode(CoderContext *const context) { // Get Tensor Pointer Collect(context, { - "nnacl/fp32/add_fp32.h", + "nnacl_c/fp32/add_fp32.h", }, { "add_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h index f60a80fe7d5c0fee0597c3453360b0ed9c596763..b71d0818804fb235dee514c76541124ab071ab0b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/affine_parameter.h" +#include "nnacl_c/affine_parameter.h" #include "tools/converter/micro/coder/wrapper/base/affine_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc index 9ec429e81406bd39ef9f1a3359c0ebdd3820aadc..c94b10029e102a75048ac584a0585b5687ecaa44 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h" #include #include "coder/opcoders/file_collector.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "coder/opcoders/parallel.h" #include "coder/log.h" @@ -208,7 +208,7 @@ int ArithmeticFP32Coder::ConstTensorBroadCast(CoderContext *const context) { } FreeConstTileBuff(); NNaclFp32Serializer init_code; - Collect(context, {"wrapper/fp32/arithmetic_fp32_wrapper.h", "nnacl/fp32/arithmetic_fp32.h"}, + Collect(context, {"wrapper/fp32/arithmetic_fp32_wrapper.h", "nnacl_c/fp32/arithmetic_fp32.h"}, {"arithmetic_fp32_wrapper.c", "arithmetic_fp32.c"}); if (input_tensor_->IsConst() && arithmetic_parameter_->in_elements_num0_ != arithmetic_parameter_->out_elements_num_) { @@ -286,7 +286,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { if (arithmetic_opt_run_ == "ElementOptSub" || arithmetic_run_ == "ElementSub") { Collect(context, { - "nnacl/fp32/sub_fp32.h", + "nnacl_c/fp32/sub_fp32.h", }, { "sub_fp32.c", @@ -294,7 +294,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else if (arithmetic_opt_run_ == "ElementOptAdd" || arithmetic_run_ == "ElementAdd") { Collect(context, { - "nnacl/fp32/add_fp32.h", + "nnacl_c/fp32/add_fp32.h", }, { "add_fp32.c", @@ -304,7 +304,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else if (arithmetic_opt_run_ == "ElementOptMul" || arithmetic_run_ == "ElementMul") { Collect(context, { - "nnacl/fp32/mul_fp32.h", + "nnacl_c/fp32/mul_fp32.h", }, { "mul_fp32.c", @@ -312,7 +312,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else if (arithmetic_run_ == "ElementAddRelu") { Collect(context, { - "nnacl/fp32/add_fp32.h", + "nnacl_c/fp32/add_fp32.h", }, { "add_fp32.c", @@ -321,7 +321,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { arithmetic_run_ == "ElementDiv") { Collect(context, { - "nnacl/fp32/div_fp32.h", + "nnacl_c/fp32/div_fp32.h", }, { "div_fp32.c", @@ -329,7 +329,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else { Collect(context, { - "nnacl/fp32/arithmetic_fp32.h", + "nnacl_c/fp32/arithmetic_fp32.h", }, { "arithmetic_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h index 169ed4577af152011e51b1c12ff622891bbad7bf..4f35e507f7b7f6466da6e6b96879004ab122817d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "wrapper/fp32/arithmetic_fp32_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc index 583e1d0dda01e76809c54da9e0731c0ee910320d..c423e263d2cfb4135b5b26edec6afaad4b9fd954 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h" #include #include -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" @@ -68,7 +68,7 @@ int ArithmeticSelfFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp32/arithmetic_self_fp32.h", + "nnacl_c/fp32/arithmetic_self_fp32.h", }, { "arithmetic_self_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h index 64f57af29617b16d6943b617493fc1b3b1250751..a7c0770784560328735cef9888f250a1687125b4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h @@ -20,8 +20,8 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/fp32/arithmetic_self_fp32.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/fp32/arithmetic_self_fp32.h" +#include "nnacl_c/arithmetic_self_parameter.h" namespace mindspore::lite::micro::nnacl { using mindspore::schema::PrimitiveType_Abs; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h index 48f602549535f8bc9afd4418a030de40caebf8c2..f0f878a7c947deca70dcbb000368a6878c6fbdd4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/base/tile_base.h" +#include "nnacl_c/base/tile_base.h" namespace mindspore::lite::micro::nnacl { class AssignAddFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc index 0df6c6c9495aef025de24707128432e47b1e904a..61f6b395665a0e7925aa125891d2681c33e02cfc 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc @@ -16,8 +16,8 @@ #include "coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h" #include #include -#include "nnacl/fp32/batchnorm_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include "nnacl_c/op_base.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" @@ -58,11 +58,11 @@ int BatchnormFP32Coder::DoCode(CoderContext *const context) { MS_CHECK_PTR(var_tensor); Collect(context, { - "nnacl/fp32/batchnorm.h", - "nnacl/kernel/batch_norm.h", + "nnacl_c/fp32/batchnorm.h", + "nnacl_c/kernel/batch_norm.h", }, { - "nnacl/fp32/batchnorm.c", + "nnacl_c/fp32/batchnorm.c", }); NNaclFp32Serializer code; code.CodeStruct("bn_struct", batchnorm_struct_); 
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h
index 1f76796a8f513c8ab2f2c671a9368a13aad43ef2..e7e1760668a63bda4f700c97fed43e24766297ef 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/kernel/batch_norm.h"
+#include "nnacl_c/kernel/batch_norm.h"
 namespace mindspore::lite::micro::nnacl {
 class BatchnormFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc
index cf36a3d5d5f26e8ef518fc1b4a8f2072f0a65745..d1b48dd6d9bc76c6b6131c30026dcc2a07df0127 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc
@@ -15,6 +15,7 @@
 */
 #include "coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h"
+#include
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
@@ -35,12 +36,12 @@ int BiasAddFP32Coder::DoCode(CoderContext *ctx) {
 std::string bias_str = allocator_->GetRuntimeAddr(input_tensors_.at(kWeightIndex), true);
 Collect(ctx,
 {
- "nnacl/arithmetic_parameter.h",
- "nnacl/nnacl_utils.h",
- "nnacl/nnacl_common.h",
- "nnacl/base/arithmetic_base.h",
- "nnacl/fp32/add_fp32.h",
- "nnacl/fp32/arithmetic_fp32.h",
+ "nnacl_c/arithmetic_parameter.h",
+ "nnacl_c/nnacl_utils.h",
+ "nnacl_c/nnacl_common.h",
+ "nnacl_c/base/arithmetic_base.h",
+ "nnacl_c/fp32/add_fp32.h",
+ "nnacl_c/fp32/arithmetic_fp32.h",
 },
 {
 "arithmetic_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h
index 62abec085d2dc1dca2d22c014e148517e4b07acd..25c983db87baa5f606600369e02d422f2deb697d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class BiasAddFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc
index 6419e11d066c87d049e077c8e46c53680d505210..da17059d441f08ee3991cc4e5d538df5b34ec66e 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc
@@ -37,7 +37,7 @@ int ConcatFP32Coder::ReSize() {
 int ConcatFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/base/concat_base.h",
+ "nnacl_c/base/concat_base.h",
 },
 {
 "concat_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h
index 6f3f5c71b0ac78af358b6e2605f40147e55e2ef8..2cda19ffb4c1e74a1d667d904887313c2289767c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class ConcatFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc
index 3aece46e2718967c5654efa991d952825d298bc4..b4608f6601a08872ae78f8832a08f53f0fd12f5a 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc
@@ -17,8 +17,8 @@
 #include "coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h"
 #include "src/common/version_manager.h"
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/winograd_utils.h"
-#include "nnacl/base/conv_common_base.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+#include "nnacl_c/base/conv_common_base.h"
 #include "coder/opcoders/nnacl/fp32/convolution_fp32_coder.h"
 #include "coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h"
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h
index 0cd3adb5baf5235051b65c8d0a44a44c83433001..08678bcbacc180bb0d3af2eb7792acc1cef2e970 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class ConvDelegateCoder : public OperatorCoder {
 public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
index 537fa14e0dd6f71ba7442662f52fb1561cdfa042..492af5376a7e40b29537759305c0565b8a1c8ef6 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
@@ -84,7 +84,7 @@ void ConvolutionDepthwiseFP32Coder::InitCodeOnline(CoderContext *const context)
 }
 Collect(context,
 {
- "nnacl/fp32/pack_fp32.h",
+ "nnacl_c/fp32/pack_fp32.h",
 },
 {"pack_fp32.c"});
 NNaclFp32Serializer init_code;
@@ -117,7 +117,7 @@ void ConvolutionDepthwiseFP32Coder::CollectFilesForFunc(CoderContext *const cont
 }
 Collect(context,
 {
- "nnacl/fp32/conv_depthwise_fp32.h",
+ "nnacl_c/fp32/conv_depthwise_fp32.h",
 },
 {
 "conv_depthwise_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc
index cbdf2e7121eed34c4decf31ff36c3d41f27eb6b0..6008a5f8542c967c3f4af2024074bbd9b8094cd9 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -145,10 +145,10 @@ void ConvolutionFP32Coder::CollectFilesForFunc(CoderContext *const context) {
 }
 Collect(context,
 {
- "nnacl/fp32/conv_common_fp32.h",
- "nnacl/fp32/matmul_fp32.h",
- "nnacl/conv_parameter.h",
- "nnacl/op_base.h",
+ "nnacl_c/fp32/conv_common_fp32.h",
+ "nnacl_c/fp32/matmul_fp32.h",
+ "nnacl_c/conv_parameter.h",
+ "nnacl_c/op_base.h",
 },
 {
 "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h
index cf6ad6148da8104f45f84c7e226d5263db2d51e4..94fbe0598efbcfd21a7da924664fa8792d358b27 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/base/conv2d_base_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
index c15d31018171c2eb12ebb551c9f4c4967e272215..39938f5591ed52dec7d46f8acc852f1c62b2c3d7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
@@ -15,7 +15,7 @@
 */
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
 #include
-#include "nnacl/base/minimal_filtering_generator.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/file_collector.h"
@@ -226,10 +226,10 @@ void ConvolutionWinogradFP32Coder::InitCodeOnline(CoderContext *const context) {
 }
 Collect(context,
 {
- "nnacl/base/minimal_filtering_generator.h",
- "nnacl/fp32/pack_fp32.h",
+ "nnacl_c/base/minimal_filtering_generator.h",
+ "nnacl_c/fp32/pack_fp32.h",
 },
- {"minimal_filtering_generator.c", "nnacl/fp32/pack_fp32.h"});
+ {"minimal_filtering_generator.c", "nnacl_c/fp32/pack_fp32.h"});
 NNaclFp32Serializer init_code;
 init_code.CodeBufferOffsetExpression(trans_weight_, context->weight_name(), context->weight_offset_name(),
 context->weight_size_name(), trans_weight_size_);
@@ -279,8 +279,8 @@ void ConvolutionWinogradFP32Coder::CollectFilesForFunc(CoderContext *const conte
 }
 Collect(context,
 {
- "nnacl/fp32/conv_winograd_fp32.h",
- "nnacl/common_func.h",
+ "nnacl_c/fp32/conv_winograd_fp32.h",
+ "nnacl_c/common_func.h",
 },
 {
 "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
index c5c3a53492a0bbd1dce4bbcb1c44202b8e029c30..a3d6489f4a5b2988e244dd29d087c1155697ed57 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
@@ -21,7 +21,7 @@
 #include
 #include
 #include "coder/opcoders/base/conv2d_base_coder.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
index ecbc6701f1c25d9da077988e624ba422738b15b4..e5ba2f5a1adb21bb2c9de276b90bec4cbcfb5606 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
-#include "nnacl/custom_gru_parameter.h"
+#include "nnacl_c/custom_gru_parameter.h"
 using mindspore::schema::PrimitiveType_Custom;
@@ -127,7 +127,7 @@ int CustomGruFP32Coder::InitWeightAndBias() {
 }
 void CustomGruFP32Coder::InitNnaclFile(CoderContext *const context) {
- Collect(context, {"nnacl/fp32/custom_gru_fp32.h"},
+ Collect(context, {"nnacl_c/fp32/custom_gru_fp32.h"},
 {"custom_gru_fp32.c", "pack_fp32.c", "matmul_fp32.c", "arithmetic_fp32.c", "activation_fp32.c"});
 }
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
index f7ccbf313ded9e484a352e440ed37c95022dec26..ffb0714056322e79771b78115bbd0bc2d4958c69 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/custom_gru_parameter.h"
+#include "nnacl_c/custom_gru_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
index 83bf34d45b6cb05e84088dc9c09fc7d052ac5d84..181bef9fe4e042439c91800489d4bf40dbfce7f1 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -148,15 +148,15 @@ void DeConvolutionFP32Coder::CollectFilesForFunc(CoderContext *const context) {
 Collect(context,
 {
 "wrapper/fp32/deconvolution_fp32_wrapper.h",
- "nnacl/fp32/conv_common_fp32.h",
- "nnacl/pack.h",
- "nnacl/fp32/common_func_fp32.h",
- "nnacl/base/minimal_filtering_generator.h",
- "nnacl/fp32/matmul_fp32.h",
- "nnacl/conv_parameter.h",
- "nnacl/matmul_parameter.h",
+ "nnacl_c/fp32/conv_common_fp32.h",
+ "nnacl_c/pack.h",
+ "nnacl_c/fp32/common_func_fp32.h",
+ "nnacl_c/base/minimal_filtering_generator.h",
+ "nnacl_c/fp32/matmul_fp32.h",
+ "nnacl_c/conv_parameter.h",
+ "nnacl_c/matmul_parameter.h",
 "wrapper/base/micro_parameter.h",
- "nnacl/op_base.h",
+ "nnacl_c/op_base.h",
 },
 {
 "deconvolution_fp32_wrapper.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h
index b6901eb07edb2ab94271b7de1d5a6d55708ed269..a01bc376273116ff83a0f7f7c51c8b80c3f3bc75 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h
@@ -19,11 +19,11 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/base/conv2d_base_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
-#include "nnacl/fp32/deconv_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/deconv_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 namespace mindspore::lite::micro::nnacl {
 class DeConvolutionFP32Coder : public Conv2DBaseCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
index 0cbc7ea1f26446911559a5be3ca3e4dc832fd857..3fb054b50c31aa5d82e9094f6300aef8d57bbd43 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
@@ -44,7 +44,7 @@ int ExpFP32Coder::Prepare(CoderContext *context) {
 int ExpFP32Coder::DoCode(CoderContext *ctx) {
 Collect(ctx,
 {
- "nnacl/fp32/exp_fp32.h",
+ "nnacl_c/fp32/exp_fp32.h",
 },
 {
 "exp_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h
index 2be5cc5b907d5e87f9d16e3bc96c0c253117a67f..20f4628f21badeecab93d0788288f15c289aac24 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_EXP_FP32_CODER_H_
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/fp32/exp_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
 namespace mindspore::lite::micro::nnacl {
 class ExpFP32Coder final : public OperatorCoder {
 public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc
index 32d89de92fef39d5ffa3d51e2ba44783dcc28253..beb21734ac6da1ec5a0ffe42964c5e045dbfe6bf 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc
@@ -36,7 +36,7 @@ int FillFP32Coder::Prepare(CoderContext *context) {
 int FillFP32Coder::DoCode(CoderContext *ctx) {
 Collect(ctx,
 {
- "nnacl/kernel/fill.h",
+ "nnacl_c/kernel/fill.h",
 },
 {
 "fill.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h
index ccd47e5c93a860436f79f832f20b49b695dbf615..c9f6cf23dc20ab282a4ccf51b6801d8121c2b0ba 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_FILL_FP32_CODER_H_
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/kernel/fill.h"
+#include "nnacl_c/kernel/fill.h"
 namespace mindspore::lite::micro::nnacl {
 class FillFP32Coder final : public OperatorCoder {
 public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc
index 3cea0e352aea17af6cb947188434cbe3efff8493..a9066d29152182c6439bcbb88de7578b0f4d479f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h"
 #include
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/utils/coder_utils.h"
@@ -44,7 +44,7 @@ int GatherDynamicFP32Coder::Prepare(CoderContext *const context) {
 int GatherDynamicFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/base/gather_base.h",
+ "nnacl_c/base/gather_base.h",
 },
 {
 "gather_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
index 1208baddaf1af70f16a8c02eab59dd98cdf362ec..0040822cd59a810e025332d7a2e375777a6ea309 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/gather_fp32_coder.h"
 #include
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/file_collector.h"
@@ -55,7 +55,7 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
 // generate code .h .c
 Collect(context,
 {
- "nnacl/base/gather_base.h",
+ "nnacl_c/base/gather_base.h",
 },
 {
 "gather_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
index 6bf7ae6a5f2383f4d9f4546ebd4bd003d09b72ea..062e247db63f6204d713411f6886629fcbe37b6b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/base/tile_base.h"
+#include "nnacl_c/base/tile_base.h"
 namespace mindspore::lite::micro::nnacl {
 class GatherFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
index b5415861f3c8239d9c90303afb87d474a6fb7938..a66d3c451c2deac06f9f76af23df352e0dff18b9 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
@@ -16,8 +16,8 @@
 #include "coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.h"
 #include
 #include
-#include "nnacl/fp32/group_norm_fp32.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/fp32/group_norm_fp32.h"
+#include "nnacl_c/op_base.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
@@ -68,7 +68,7 @@ int GroupNormFP32Coder::DoCode(CoderContext *const context) {
 MS_CHECK_PTR(offset_tensor);
 Collect(context,
 {
- "nnacl/fp32/group_norm_fp32.h",
+ "nnacl_c/fp32/group_norm_fp32.h",
 },
 {
 "group_norm_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc
index 367ab77a83ca8f88e4a54f6a12adf624b734132d..cb93a00369a78c58f1483331d957a2dd1c19fd66 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc
@@ -40,7 +40,7 @@ int InstanceNormFP32Coder::Prepare(CoderContext *const context) {
 int InstanceNormFP32Coder::DoCode(CoderContext *const context) {
 NNaclFp32Serializer code;
 code.CodeStruct("instance_norm_param", *param_);
- Collect(context, {"nnacl/fp32/pack_fp32.h", "nnacl/fp32/instance_norm_fp32.h"},
+ Collect(context, {"nnacl_c/fp32/pack_fp32.h", "nnacl_c/fp32/instance_norm_fp32.h"},
 {"pack_fp32.c", "instance_norm_fp32.c"});
 if (input_tensors_[0]->format() == NHWC) {
 code.CodeFunction("PackNHWCToNC4HW4NotAlignedFp32", input_tensor_, tmp_src_data_, param_->batch_,
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h
index 2fd42c4f4f685a20ba206e965afa7350fd5246d0..ad4fa8136e72588afcf6dad632e73ebf83a29f5b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/instance_norm_parameter.h"
+#include "nnacl_c/instance_norm_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class InstanceNormFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc
index e5d6b10ec29c2f575d730232e2c6718b7a16ce15..f5f135a45522988c247e88d9881f17ff0b38395f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc
@@ -68,7 +68,7 @@ int LayerNormFP32Coder::Prepare(CoderContext *const context) {
 int LayerNormFP32Coder::DoCode(CoderContext *const context) {
 NNaclFp32Serializer code;
 code.CodeStruct("layer_norm_compute_parm", compute_);
- Collect(context, {"nnacl/fp32/layer_norm_fp32.h"}, {"layer_norm_fp32.c"});
+ Collect(context, {"nnacl_c/fp32/layer_norm_fp32.h"}, {"layer_norm_fp32.c"});
 if (output_tensors_.size() == kOutputNum) {
 code.CodeFunction("LayerNorm", input_tensor_, input_tensors_.at(SECOND_INPUT), input_tensors_.at(THIRD_INPUT),
 output_tensor_, output_tensors_.at(SECOND_INPUT), output_tensors_.at(THIRD_INPUT),
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h
index 4842d9cfbae91ef4188780f8c817dc4cf701ae41..bde27e13c3ce14e334c927dc093623b373d55e7a 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h
@@ -19,8 +19,8 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/layer_norm_parameter.h"
-#include "nnacl/kernel/layer_norm.h"
+#include "nnacl_c/layer_norm_parameter.h"
+#include "nnacl_c/kernel/layer_norm.h"
 namespace mindspore::lite::micro::nnacl {
 class LayerNormFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
index 30f95332e44fe85ab30c5a869721e088e74db7ff..ab42dac9587cc3ead11a5f8b4545bf3a97109496 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
@@ -182,8 +182,8 @@ int LstmFP32Coder::Prepare(CoderContext *const context) {
 int LstmFP32Coder::DoCode(CoderContext *context) {
 Collect(context,
 {
- "nnacl/lstm_parameter.h",
- "nnacl/fp32/lstm_fp32.h",
+ "nnacl_c/lstm_parameter.h",
+ "nnacl_c/fp32/lstm_fp32.h",
 },
 {
 "lstm_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h
index b54bf53a24e1b74fb1a05370b0addfa5987f3304..eed3c3c16016ac3e38207c418a30a4e5ce70282c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/lstm_parameter.h"
+#include "nnacl_c/lstm_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class LstmFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
index 84438cc8bfb0fc6fcf669eae792049703f686460..0193dfa76bea95b63c8cc4e0eff4afd661dd3d43 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
@@ -21,7 +21,7 @@
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 #include "wrapper/fp32/matmul_fp32_wrapper.h"
 #include "coder/opcoders/nnacl/dequant/de_quant.h"
@@ -160,8 +160,8 @@ int MatMulFP32BaseCoder::Prepare(CoderContext *const context) { return RET_OK; }
 int MatMulFP32BaseCoder::CollectFilesForTarget(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/fp32/pack_fp32.h",
- "nnacl/fp32/matmul_fp32.h",
+ "nnacl_c/fp32/pack_fp32.h",
+ "nnacl_c/fp32/matmul_fp32.h",
 "wrapper/fp32/matmul_fp32_wrapper.h",
 },
 {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
index 6e92508a2179460f9f5840c29d53d4f941542b16..11a21413586a0c44da62264ddedf874ae5a61dbe 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h
index 046a2a1a8f5f620d9553f7b8ebdfa9c165d07357..e05bae8cf45632633d7a34419de2b86c914567cb 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class MatMulFP32Coder final : public MatMulFP32BaseCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc
index 32909be176f24a71b7019bac93a21390ac11c7be..5646ad255ba8a2daeb7329ecca091c1ba737a30c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc
@@ -29,7 +29,7 @@ int OnesLikeFP32Coder::Prepare(CoderContext *const context) { return RET_OK; }
 int OnesLikeFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/kernel/ones_like.h",
+ "nnacl_c/kernel/ones_like.h",
 },
 {
 "ones_like.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc
index e51b055f122e6e86dd30ea55b6172106bfc524c6..be7212f1f58fad47163475c43916e75d33c7eddf 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc
@@ -81,8 +81,8 @@ int PadFP32Coder::ExtendPaddings(int *paddings, int length, const int *ori_paddi
 int PadFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/fp32/pad_fp32.h",
- "nnacl/pad_parameter.h",
+ "nnacl_c/fp32/pad_fp32.h",
+ "nnacl_c/pad_parameter.h",
 },
 {
 "pad_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h
index 3cf3ff31b35a76966a4fb8ad413213656fc89521..30446a871c61af5ba6a7c46a4f8e852944406e4b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/fp32/pad_fp32.h"
+#include "nnacl_c/fp32/pad_fp32.h"
 namespace mindspore::lite::micro::nnacl {
 class PadFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc
index b02dc33605567ca61973831580ee2691fa9deef4..25bdc4958a64dddd0bf823410ef5e9c6e77de6a2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/pooling_fp32_coder.h"
 #include
 #include
-#include "nnacl/fp32/pooling_fp32.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -49,8 +49,8 @@ int PoolingFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
 "wrapper/fp32/pooling_fp32_wrapper.h",
- "nnacl/kernel/pooling.h",
- "nnacl/fp32/pooling_fp32.h",
+ "nnacl_c/kernel/pooling.h",
+ "nnacl_c/fp32/pooling_fp32.h",
 },
 {
 "pooling_fp32_wrapper.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h
index 37bebc00b9527b6567cd5f52b595fa7229c9fe69..79f8844d63365c8f8dbe74cc696207b0cde10e73 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/kernel/pooling.h"
+#include "nnacl_c/kernel/pooling.h"
 namespace mindspore::lite::micro::nnacl {
 class PoolingFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc
index 462d1a45dd1e29ad233c3609890754ef54f753ec..edc05ea02349e8ae2e68d8b205f26b24def68cb7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc
@@ -51,8 +51,8 @@ int PowerFP32Coder::DoCode(CoderContext *const context) {
 // generate code .h .c
 Collect(context,
 {
- "nnacl/pow_parameter.h",
- "nnacl/fp32/power_fp32.h",
+ "nnacl_c/pow_parameter.h",
+ "nnacl_c/fp32/power_fp32.h",
 },
 {
 "power_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h
index 84fb3a393087e40f0ecb6d4c42c5da88b7b22bdb..e99ff5948ff631bc15f1c41f0664476c8b23dc8d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/pow_parameter.h"
+#include "nnacl_c/pow_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class PowerFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
index defb669ae2a0939c3aa60052a2b708292c71a6fa..68d7430d3bb6767278afd5323e2fc9a92a0a8d6d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
@@ -15,8 +15,8 @@
 */
 #include "coder/opcoders/nnacl/fp32/prelu_fp32_coder.h"
 #include
-#include "nnacl/fp32/prelu_fp32.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/fp32/prelu_fp32.h"
+#include "nnacl_c/op_base.h"
 #include "coder/allocator/allocator.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
@@ -29,7 +29,7 @@ int PReluFP32Coder::DoCode(CoderContext *const context) {
 int count = input_tensor_->ElementsNum();
 Collect(context,
 {
- "nnacl/fp32/prelu_fp32.h",
+ "nnacl_c/fp32/prelu_fp32.h",
 },
 {
 "prelu_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc
index 2970309aadbd8d359e18fb00f41066f97b8b1609..ed6127c27938174137c251e447564b5baad0f614 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc
@@ -50,7 +50,7 @@ int ReduceFP32Coder::ReSize() {
 int ReduceFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/fp32/reduce_fp32.h",
+ "nnacl_c/fp32/reduce_fp32.h",
 },
 {
 "reduce_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc
index d84d0c6057f41ae71d53ca5ea6d0c7a8c315e463..fdffada831f6b62f96d019983f4848d741a670a6 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc
@@ -159,7 +159,7 @@ int ResizeFP32Coder::ResizePrepare() {
 int ResizeFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
- "nnacl/fp32/resize_fp32.h",
+ "nnacl_c/fp32/resize_fp32.h",
 },
 {
 "resize_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h
index 6654df2cd2d3c507f93d86ea3faa61e487fcd990..9375ed99f297e6adeb95524f5872a319f825ad1a 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h
@@ -22,7 +22,7 @@
 #include
 #include
 #include "include/errorcode.h"
-#include "nnacl/fp32/resize_fp32.h"
+#include "nnacl_c/fp32/resize_fp32.h"
 #include "src/executor/kernel_exec.h"
 #include "src/litert/kernel/cpu/fp32/resize_fp32.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc
index d4b3ca1495255b48bc98fc64c42b0da15bae89d3..e88d491210028c723076bfd98f634f845b298a4d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc
@@ -93,9 +93,9 @@ int ScaleFP32Coder::DoCode(CoderContext *const context) {
 Collect(context,
 {
"wrapper/fp32/scale_fp32_wrapper.h", - "nnacl/scale_parameter.h", - "nnacl/kernel/scale.h", - "nnacl/fp32/scale_fp32.h", + "nnacl_c/scale_parameter.h", + "nnacl_c/kernel/scale.h", + "nnacl_c/fp32/scale_fp32.h", }, { "scale_fp32_wrapper.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h index 9a764e2ecb8cf8b025d4dcac84a0630782e20bdc..2dd048c1001f5e430ada4b6f0d150693b34aa821 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h @@ -19,8 +19,8 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/kernel/scale.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" namespace mindspore::lite::micro::nnacl { class ScaleFP32Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc index 8e8d88a61074094c333c5c8031324002240d7ab5..2566297a5ee557f633e49da679391e0e0515c4c0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc @@ -17,8 +17,8 @@ #include "tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h" #include "tools/converter/micro/coder/opcoders/file_collector.h" #include "tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" -#include "nnacl/slice_parameter.h" -#include "nnacl/base/slice_base.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/base/slice_base.h" #include "coder/opcoders/parallel.h" using mindspore::schema::PrimitiveType_SliceFusion; @@ -73,7 +73,7 @@ int SliceFP32Coder::Prepare(CoderContext *const context) { int SliceFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/slice_base.h", + "nnacl_c/base/slice_base.h", "wrapper/fp32/slice_fp32_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h index 213843238433857adb13ac340669a4514fbf7a97..37b999a23ec7d27a6c094e3a0155be82a291ebde 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" -#include "nnacl/kernel/slice.h" +#include "nnacl_c/kernel/slice.h" namespace mindspore::lite::micro::nnacl { class SliceFP32Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc index 1bf18c06c05f2a47ba2b76205ac501d26ed35ae6..e7075b55c8e30199355dd74f0b6f7d74cb835dd0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc @@ -36,8 +36,8 @@ int SoftMaxFP32Coder::Prepare(CoderContext *const context) { int SoftMaxFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp32/softmax_fp32.h", - 
"nnacl/fp32/log_softmax_fp32.h", + "nnacl_c/fp32/softmax_fp32.h", + "nnacl_c/fp32/log_softmax_fp32.h", }, { "softmax_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc index 4460f914b835199dee4d384d6cb5e79ac7f13c46..1529179d1aa774c16939dbed6f971c1d5b459523 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc @@ -19,7 +19,7 @@ #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" #include "src/common/log_adapter.h" -#include "nnacl/splice_parameter.h" +#include "nnacl_c/splice_parameter.h" using mindspore::schema::PrimitiveType_Splice; namespace mindspore::lite::micro::nnacl { int SpliceFP32Coder::DoCode(CoderContext *const context) { @@ -42,8 +42,8 @@ int SpliceFP32Coder::DoCode(CoderContext *const context) { } Collect(context, { - "nnacl/splice_parameter.h", - "nnacl/fp32/splice_fp32.h", + "nnacl_c/splice_parameter.h", + "nnacl_c/fp32/splice_fp32.h", }, { "splice_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc index c8778a7d032dfbd03ffc402af6bc7c20b5714bd4..6392fcd61f0677b6c31b14136915c0a9bfa7bd30 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc @@ -19,7 +19,7 @@ #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/utils/coder_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::schema::PrimitiveType_Split; @@ -64,7 +64,7 @@ int SplitDynamicFP32Coder::Prepare(CoderContext *const context) { } int SplitDynamicFP32Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/base/split_base.h"}, {"split_base.c"}); + Collect(context, {"nnacl_c/base/split_base.h"}, {"split_base.c"}); NNaclFp32Serializer code; code << " void *output_ptrs[" << output_tensors_.size() << "] = {"; for (int i = 0; i < param_->num_split_; i++) { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h index 88ca2bfed27cac4808fbe2d8f1f205a2f584ae86..efc253c0a8796e7e0d2f1d23da9ac8ba4e070ada 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h @@ -20,7 +20,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite::micro::nnacl { class SplitDynamicFP32Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc index 433185765a26aa7f1b48584520110417f21d6ff8..88a30e33ddbfc60762878afe02a2e52b3d49d3c9 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc @@ -33,7 +33,7 @@ int SplitFP32Coder::Prepare(CoderContext *const context) { } int SplitFP32Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/base/split_base.h"}, {"split_base.c"}); + Collect(context, {"nnacl_c/base/split_base.h"}, {"split_base.c"}); if (support_parallel_) { Collect(context, {"wrapper/fp32/split_fp32_wrapper.h"}, {"split_fp32_wrapper.c"}); } diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h index f65214c118df855a4127087d814a1c23248a5d3d..32d75f2d4a139dab87e3e60a2b8a3da8695ba767 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite::micro::nnacl { class SplitFP32Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc index 5a663f1bcbca5241ac576c5bdee1392959499a40..da8f1e1168b7a822f8be6366463ad708c28fa0e3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc @@ -47,10 +47,10 @@ int TileFP32Coder::DoCode(CoderContext *const context) { // generate code .h .c Collect(context, { - "nnacl/fp32/tile.h", + "nnacl_c/fp32/tile.h", }, { - "nnacl/fp32/tile.c", + "nnacl_c/fp32/tile.c", }); NNaclFp32Serializer code; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h index c0627cc594cf2ba7d24ff8e2d7cdf6d3e8265ffe..294e3f1b89b5429a234bc175440ac43f0e5ef2d6 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/base/tile_base.h" +#include "nnacl_c/base/tile_base.h" namespace mindspore::lite::micro::nnacl { class TileFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc index 7fb160d5674bba582808651200092ba6ff2b31bf..1f2bade904ab1f784c85000012e4b0f705a885ca 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc @@ -42,9 +42,9 @@ int TransposeDynamicFp32Coder::Prepare(CoderContext *const context) { int TransposeDynamicFp32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/transpose_parameter.h", - "nnacl/errorcode.h", - "nnacl/fp32/transpose_fp32.h", + "nnacl_c/transpose_parameter.h", + "nnacl_c/errorcode.h", + "nnacl_c/fp32/transpose_fp32.h", }, { "transpose_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h index 9230b8e38d10ac067a1602f8930e28ecd8b74f0c..f956c28176785dd10269afbfe65a9e95f3a80b24 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc index 97b4677572b4f0839a8261b9d652749ba5b5ba16..4bacbb201a8f29dad69012358260d6f41dd9083b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc @@ -107,9 +107,9 @@ int TransposeFp32Coder::DoCode(CoderContext *const context) { Collect(context, { "wrapper/fp32/transpose_fp32_wrapper.h", - "nnacl/transpose_parameter.h", - "nnacl/errorcode.h", - "nnacl/fp32/transpose_fp32.h", + "nnacl_c/transpose_parameter.h", + "nnacl_c/errorcode.h", + "nnacl_c/fp32/transpose_fp32.h", }, { "transpose_fp32_wrapper.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h index 1b81dc26a526e499f0aa18667385df0398af31ea..737516d7d0af7bd7a38faab65fd525c3339e4c0d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" namespace mindspore::lite::micro::nnacl { class TransposeFp32Coder : public OperatorCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc index b9d940d35e165a0fc3a6fc8a2300b4f68f9bc69e..4b867f11a054f7c27895ea88c664b54d5d2e905c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/fp32_grad/activation_grad_coder.h" -#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl_c/fp32_grad/activation_grad_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" @@ -31,7 +31,7 @@ int ActivationGradCoder::DoCode(CoderContext *const context) { int count = input_tensor_->ElementsNum(); Collect(context, { - "nnacl/fp32_grad/activation_grad_fp32.h", + "nnacl_c/fp32_grad/activation_grad_fp32.h", }, { "activation_grad_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc index 
c4427cf9e2214d6151ed21328ea7807d7ea3c116..04821d8c70d1bdd81e7248c3ecb3ed49dcbc9fd7 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/fp32_grad/adam_coder.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" @@ -53,7 +53,7 @@ int AdamCoder::DoCode(CoderContext *const context) { auto *adam_param = reinterpret_cast(parameter_); Collect(context, { - "nnacl/fp32/adam_fp32.h", + "nnacl_c/fp32/adam_fp32.h", }, { "adam_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc index c194f0cabd1ed0eced705b0579e88f899cddfa81..4b743fcd8d3ac64875f355bf6a0f3e04f479cfe8 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h" #include -#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h" +#include "nnacl_c/fp32_grad/softmax_crossentropy_parameter.h" #include "coder/opcoders/file_collector.h" #include "schema/inner/ops_generated.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" @@ -49,8 +49,8 @@ int SoftmaxCrossEntropyWithLogitsCoder::DoCode(CoderContext *const context) { MS_CHECK_TRUE(input_tensors_.size() == DIMENSION_2D, "inputs size is not equal to two"); Collect(context, { - "nnacl/fp32/softmax_fp32.h", - "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h", + "nnacl_c/fp32/softmax_fp32.h", + "nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h", }, { "softmax_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h index 3aea3e4d27dde32a00996385987e3f72ab6b1c68..6161589bf2cd6f28cdf299ceb4f473e63077c5a5 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CODER_H_ #include -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "coder/opcoders/op_coder.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc index 762dcd4963fdeb20c8960b832d627c9524ee49aa..38cd39b22c22b3e5195878010ded9e8d0b5ee7ad 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc @@ -18,7 +18,7 @@ #include "coder/opcoders/nnacl/int8/relux_int8_coder.h" 
#include "coder/opcoders/nnacl/int8/tanh_int8_coder.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "schema/model_generated.h" #include "src/common/version_manager.h" #include "coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc index d9732ca62ccabf0845d04757b32a99948dd81a77..266a2f32d53e90a622e3792cb90f6b49521b889a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc @@ -23,7 +23,7 @@ #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/utils/common.h" #include "mindspore/ops/op_def/array_ops.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" using mindspore::schema::PrimitiveType_AddFusion; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h index 06809bd8d41eb818040ac1e3c34a8cedf2b9df7f..1b8a832e7847e06aae39b65f59c6dbf1444c2a55 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/add_int8.h" +#include "nnacl_c/int8/add_int8.h" namespace mindspore::lite::micro::nnacl { class AddInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h index 97acf295502c31370d860d9db6329e6546afe1de..09eec6802d3335d5ddb136910909ca8f26c3d1c8 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/affine_parameter.h" +#include "nnacl_c/affine_parameter.h" #include "tools/converter/micro/coder/wrapper/base/affine_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc index bf8a074d642961249d7e364d184d1a6edb7f62f3..b31fc592bd78fdde65c1160da72ec2e6be12c9e1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc @@ -114,7 +114,7 @@ int ArithmeticSelfInt8Coder::Prepare(CoderContext *context) { int ArithmeticSelfInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/arithmetic_self_int8.h", + "nnacl_c/int8/arithmetic_self_int8.h", }, { "arithmetic_self_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h index 5dcf1373281da6473f7b1be22ce6a401ce4fd3b2..3f7695ff61709aca46e24b5f4b9b251f19f672f5 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h @@ -20,9 +20,9 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/arithmetic_self_int8.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/arithmetic_self_int8.h" +#include "nnacl_c/arithmetic_self_parameter.h" namespace mindspore::lite::micro::nnacl { class ArithmeticSelfInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc index 59a3df56576bcafbbef4f58de83954686011131e..9616f46901a4b9934c1dba9149074b968f721a2f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc @@ -54,8 +54,8 @@ int BatchNormInt8Coder::DoCode(CoderContext *context) { Collect(context, { - "nnacl/slice_parameter.h", - "nnacl/kernel/batch_norm.h", + "nnacl_c/slice_parameter.h", + "nnacl_c/kernel/batch_norm.h", }, { "batchnorm_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h index 6cba41bc397b21a75dbfc7f779e871d10da85306..a38e7fc7ed37f8a139e843c8f735cb9d22b99372 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" namespace mindspore::lite::micro::nnacl { class BatchNormInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc index 5aea116d04799b3a202ba10cb4704b072bfc7019..41b2eb5d1371d542c6fe7f6900a5e3c30b1a4c87 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc @@ -16,9 +16,9 @@ #include "coder/opcoders/nnacl/int8/concat_int8_coder.h" #include -#include "nnacl/int8/concat_int8.h" -#include "nnacl/int8/arithmetic_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/concat_int8.h" +#include "nnacl_c/int8/arithmetic_int8.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -100,7 +100,7 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/concat_int8.h", + "nnacl_c/int8/concat_int8.h", "wrapper/int8/concat_int8_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h index 46a8bdb67432d60bc36661d3c24948b2421c2ca4..25c80fd774fdfab85a0112a8e5ee0480baa4b2f5 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/concat_int8.h" +#include "nnacl_c/int8/concat_int8.h" #include "wrapper/int8/concat_int8_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc index 14727188731f9a52d9396571ec6a443eb2d8c0d8..a8d11f51de9d3b6f3085aa9c40621176237efdf1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc @@ -48,12 +48,12 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) { "wrapper/int8/conv1x1_init_int8_wrapper.h", "wrapper/int8/conv1x1_run_int8_wrapper.h", "wrapper/base/micro_parameter.h", - "nnacl/common_func.h", - "nnacl/base/conv1x1_base.h", - "nnacl/int8/matmul_int8.h", - "nnacl/int8/pack_int8.h", - "nnacl/int8/conv1x1_int8.h", - "nnacl/errorcode.h", + "nnacl_c/common_func.h", + "nnacl_c/base/conv1x1_base.h", + "nnacl_c/int8/matmul_int8.h", + "nnacl_c/int8/pack_int8.h", + "nnacl_c/int8/conv1x1_int8.h", + "nnacl_c/errorcode.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h index eff155ff84308834c5c37cc13f48b5d156222250..97b6843e031add57aa0022734a9430cfd821ea0b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "wrapper/base/micro_parameter.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc index 9ebddef3a714c213f88133103e31054f1b0885ec..8e8b594d1d53f49a609f07883dbd06e5d2e46934 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h" #include #include "include/securec.h" -#include "nnacl/int8/conv3x3_int8.h" +#include "nnacl_c/int8/conv3x3_int8.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -128,8 +128,8 @@ int Conv2D3x3Int8Coder::Prepare(CoderContext *const context) { int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/conv_int8.h", - "nnacl/int8/conv3x3_int8.h", + "nnacl_c/int8/conv_int8.h", + "nnacl_c/int8/conv3x3_int8.h", }, { "pack_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h index bd96e3e2c8537702bfcd591ea851b8f26adb5a21..d521ff1fc7c979f7406c82dc38bc818869c92a1d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::nnacl { class Conv2D3x3Int8Coder final : public Conv2DBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc index a519519981fc6d31974e254a20ac31146c485f60..99479ebee8d8d5c0ed5222db618f1cb851897fd5 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc @@ -214,8 +214,8 @@ int Conv2DINT8Coder::DoCode(CoderContext *const context) { } Collect(context, { - "nnacl/int8/conv_int8.h", - "nnacl/common_func.h", + "nnacl_c/int8/conv_int8.h", + "nnacl_c/common_func.h", "wrapper/int8/convolution_int8_wrapper.h", "wrapper/base/common_wrapper.h", "wrapper/base/optimize_handler_wrapper.h", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h index ddfc69130f1e8d8499d3e0dd18bffdf4b4721b0c..bab8a56aaff683311c04a73d838ceabd6fe65ab2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/base/conv2d_base_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc index 96a3e1e4250aa4d7b6ac474019af59547d5d9226..7b8b71b24f85e451d2276ac083dcab3593b3e7d2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc @@ -19,7 +19,7 @@ #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" -#include "nnacl/int8/conv_depthwise_int8.h" +#include "nnacl_c/int8/conv_depthwise_int8.h" namespace mindspore::lite::micro { int ConvolutionDepthwiseINT8Coder::Prepare(CoderContext *const context) { @@ -90,8 +90,8 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) { "Only support input channel equals output channel."); Collect(context, { - "nnacl/int8/conv_depthwise_int8.h", - "nnacl/int8/pack_int8.h", + "nnacl_c/int8/conv_depthwise_int8.h", + "nnacl_c/int8/pack_int8.h", "wrapper/int8/convolution_depthwise_int8_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc index 27fdc3f520d0d01ad2b772549606f47aef579ada..0dfdc993eb7889792041462116f6f9d6db4294dd 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/int8/deconvolution_int8_coder.h" #include -#include "nnacl/int8/deconv_int8.h" +#include "nnacl_c/int8/deconv_int8.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" @@ -125,7 +125,7 @@ int DeconvolutionInt8Coder::InitRunBuf(CoderContext *const context) { int DeconvolutionInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/deconv_int8.h", + "nnacl_c/int8/deconv_int8.h", }, { "deconv_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h index c8ae3ccb89162fc7d6c24a92758bba83d475e996..0acd4f17d7f918af1c80002ed7ca15623334a3d2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/base/conv2d_base_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class DeconvolutionInt8Coder final : public Conv2DBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc index 00783220b69d814e4ea674c2f9313590ba98dc3a..23f3d45bde0c779093107e772b1bc81f961f70ed 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc @@ -45,7 +45,7 @@ int DetectionPostProcessInt8Coder::GetInputData(CoderContext *const context, Ser Collect(context, { - "nnacl/int8/quant_dtype_cast_int8.h", + "nnacl_c/int8/quant_dtype_cast_int8.h", }, { "quant_dtype_cast_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc index e6b86e7c5e4d98bbd4e32f3e77ad61533abe76ca..83a967bf14d5c601d7f273310767c2a121222fa4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc @@ -53,7 +53,7 @@ int DivInt8Coder::Prepare(CoderContext *context) { int DivInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/div_int8.h", + "nnacl_c/int8/div_int8.h", }, { "div_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h index 6cb9cb913efb55ab2b7c3b70a704781ee230f473..4a8859e0cc7a0206280f5379829f5df24de721d4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro::nnacl { class DivInt8Coder final : public OperatorCoder { diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc index fdefd4a52391514442d1a6177da6e0fb613e3466..c45d3ceaa10247a0058e703ee3bca568c870096c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/int8/fullconnection_int8_coder.h" -#include "nnacl/int8/matmul_int8.h" +#include "nnacl_c/int8/matmul_int8.h" #include "coder/log.h" using mindspore::schema::PrimitiveType_FullConnection; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h index c5946a0cd686ddc1983048fd80587e567a1c0453..642f520b7800249f1122cdd0320fb101e3dd4744 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h @@ -21,8 +21,8 @@ #include #include #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class FullConnectionInt8Coder final : public MatMulBaseInt8Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc index 59491e61e4e93e9ac1cad672dc6c66c3707774eb..ed3159470c6a150f3ad7cf55a4e68a15297099f4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc @@ -50,7 +50,7 @@ int GatherInt8Coder::Prepare(CoderContext *context) { int GatherInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/gather_int8.h", + "nnacl_c/int8/gather_int8.h", }, { "gather_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h index ee4f4c355700190fd9271826d565b99cb5ee3afc..f3143efdfb084d1859fd43aa7ccdf030b46438b5 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h @@ -20,9 +20,9 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/gather_int8.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/gather_int8.h" +#include "nnacl_c/gather_parameter.h" namespace mindspore::lite::micro::nnacl { class GatherInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc index f04a6fd591b41bd779e83acdffedf3b082eb66fa..fbfbc46685e6b3ca2f2c6853ca0f9a491607f611 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc @@ -50,7 +50,7 @@ int LeakyReluInt8Coder::Prepare(CoderContext *context) { int LeakyReluInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/leaky_relu_int8.h", + "nnacl_c/int8/leaky_relu_int8.h", }, { "leaky_relu_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h index f76ec591879788acda0b50207803ada29bd23b6e..bd5fadf9e183f5a6fddbe47b4f43e82027fe267b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h @@ -20,9 +20,9 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/leaky_relu_int8.h" -#include "nnacl/activation_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/leaky_relu_int8.h" +#include "nnacl_c/activation_parameter.h" namespace mindspore::lite::micro::nnacl { class LeakyReluInt8Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc index 6cf6c17f720fdb586e4044f614443bd46aa55233..aee53e9550105d957b5468cd3c20500dbe54e555 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc @@ -274,11 +274,11 @@ void MatMulBaseInt8Coder::DoBatchCode(NNaclInt8Serializer *code_ptr) { int MatMulBaseInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/common_func.h", - "nnacl/int8/common_func_int8.h", - "nnacl/int8/matmul_int8.h", - "nnacl/int8/fixed_point.h", - "nnacl/int8/relux_int8.h", + "nnacl_c/common_func.h", + "nnacl_c/int8/common_func_int8.h", + "nnacl_c/int8/matmul_int8.h", + "nnacl_c/int8/fixed_point.h", + "nnacl_c/int8/relux_int8.h", "wrapper/int8/matmul_int8_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h index 28a9f2716b733c41cd222a1603a8a256434dee81..10c5d2cbed57dd3e5d88fd1cd80e3e43e4023b93 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h index bec68df5b2e0143df43ec1299f87c08128b10af8..5c8548b1813e3a652a22674e6b2d4aa3b777ada2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h @@ -20,7 +20,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h" -#include 
"nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulInt8Coder final : public MatMulBaseInt8Coder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc index 2474f2419fbd5981c008af51c2df8038c6acfcb2..3734306488ee27ea11164b20ef3cbbc1505cf9f4 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc @@ -205,7 +205,7 @@ int PadInt8Coder::HandleMirrorPad() { int PadInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/pad_int8.h", + "nnacl_c/int8/pad_int8.h", }, { "pad_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h index 3cd0b85afd836f5b6c218ad2e5383cb117c49ad9..8a3305faa08853a66607aa0ee068f2fe637ae5bc 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h @@ -20,9 +20,9 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/pad_int8.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/pad_int8.h" +#include "nnacl_c/pad_parameter.h" namespace mindspore::lite::micro::nnacl { class PadInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc index 1e639308e5a89e6b01c5faa0f904458e390e7894..75c7491cbc3cb76b791909fce37017a95d1a5c29 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/int8/pooling_int8_coder.h" #include #include -#include "nnacl/int8/pooling_int8.h" +#include "nnacl_c/int8/pooling_int8.h" #include "coder/log.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/opcoders/file_collector.h" @@ -54,9 +54,9 @@ int PoolingInt8Coder::DoCode(CoderContext *const context) { std::vector out_quant_args = out_tensor->quant_params(); Collect(context, { - "nnacl/int8/pooling_int8.h", - "nnacl/kernel/pooling.h", - "nnacl/errorcode.h", + "nnacl_c/int8/pooling_int8.h", + "nnacl_c/kernel/pooling.h", + "nnacl_c/errorcode.h", }, { "pooling_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h index d255f77dfdcd8ff52b4300cca115401f7d1d7988..9400bde3542d8dab55ca15176877676002e057f2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h @@ -21,7 +21,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/kernel/pooling.h" +#include "nnacl_c/kernel/pooling.h" namespace mindspore::lite::micro::nnacl { class PoolingInt8Coder final : public OperatorCoder { diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h index aacfe99d9046e0f421f5fcfb9c459d9bbb600196..c5281c57f0e6b0ba72663af57851332212a797e3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc index 6e77ded3e870e87b6182edd0eebabd5f63e71ad5..68d549829d997260b99c5558b7e74cf489d74984 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc @@ -235,8 +235,8 @@ int ReduceInt8Coder::DoCode(CoderContext *const context) { if (axes_hw_pattern_) { Collect(context, { - "nnacl/int8/pack_int8.h", - "nnacl/int8/reduce_int8.h", + "nnacl_c/int8/pack_int8.h", + "nnacl_c/int8/reduce_int8.h", }, { "pack_int8.c", @@ -256,7 +256,7 @@ int ReduceInt8Coder::DoCode(CoderContext *const context) { } else { Collect(context, { - "nnacl/int8/reduce_int8.h", + "nnacl_c/int8/reduce_int8.h", }, { "reduce_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h index 1fbf1efc77ce593c46261309183c59d2e98ddfde..03bc752aebfc0eb03f0d016b4a871a4aacc2468c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h @@ -20,8 +20,8 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/reduce_int8.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/reduce_int8.h" #include "coder/opcoders/base/reduce_base_coder.h" namespace mindspore::lite::micro::nnacl { class ReduceInt8Coder final : public ReduceBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc index 5ddf04b101619a99f5ab6258c3fb0f00f6fa7267..0f67e4a8c7b62ce95f28da923928c4a43700fd5b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/int8/relux_int8_coder.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" @@ -41,7 +41,7 @@ int ReluxInt8Coder::Prepare(CoderContext *const context) { int ReluxInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/relux_int8.h", + "nnacl_c/int8/relux_int8.h", }, { "relux_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h index aded1729fbbb625515925a65ba5c38e790703371..79ce41466fe0fd1cccbeb30187e42b288c5f9183 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h @@ -21,7 +21,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/utils/common.h" -#include "nnacl/int8/relux_int8.h" +#include "nnacl_c/int8/relux_int8.h" #include "coder/log.h" #include "include/errorcode.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc index 1533307bc8a7cd1f92a407f93a07d6c371ee50d1..2b2fb9d202baa9be31d171e4d1312610c0ea5b2d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc @@ -35,7 +35,7 @@ int ReshapeInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/reshape_int8.h", + "nnacl_c/int8/reshape_int8.h", }, { "reshape_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc index f2fea6dc7ecec38e4166895949a934095a50fe1e..6084353246bf1a975d3adceacea0c73e3319502e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc @@ -19,7 +19,7 @@ #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/opcoders/file_collector.h" #include "include/securec.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/parallel.h" using mindspore::schema::PrimitiveType_Resize; @@ -67,7 +67,7 @@ int ResizeInt8Coder::ReSize() { int ResizeInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/resize_int8.h", + "nnacl_c/int8/resize_int8.h", "wrapper/int8/resize_int8_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h index 9b7e6a8482294795354b7c5370f8df04939a5b7b..3300de10df327e4b58401164abb1df5c2b36c69b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/base/resize_base_coder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::micro::nnacl { class ResizeInt8Coder final : public ResizeBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc index 6f29562a1cc517821574e41bb13e3d78ba3a3abf..eea8127fec58ba5880595b9d496f106c2491cc2f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc @@ -55,7 +55,7 @@ int SigmodInt8Coder::Prepare(CoderContext *const context) { int SigmodInt8Coder::DoCode(CoderContext *const context) { 
   Collect(context,
           {
-            "nnacl/int8/sigmoid_int8.h",
+            "nnacl_c/int8/sigmoid_int8.h",
           },
           {
             "sigmoid_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc
index 452a452e90815d2db34f2acb19b4a76519ed8051..492e1d81133f43c092d86ecc16635eadcd2696ab 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "schema/inner/ops_generated.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 #include "coder/log.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
 #include "coder/opcoders/file_collector.h"
@@ -72,7 +72,7 @@ int SoftMaxInt8Coder::DoCode(CoderContext *const context) {
                 "n_dim should be less than the length of maximum value of input_shape");
   Collect(context,
           {
-            "nnacl/int8/softmax_int8.h",
+            "nnacl_c/int8/softmax_int8.h",
           },
           {
             "softmax_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc
index e08f4549ca71692dcedf98e1fc7da0c35eed71b9..cbf0e03ee667086d2433feffd68985062ec6c4ef 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc
@@ -16,6 +16,7 @@
 #include "coder/opcoders/nnacl/int8/sub_int8_coder.h"
 #include
+#include
 #include "include/errorcode.h"
 #include "coder/log.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
@@ -73,7 +74,7 @@ int SubInt8Coder::Prepare(CoderContext *const context) {
 }
 int SubInt8Coder::DoCode(CoderContext *const context) {
-  Collect(context, {"nnacl/int8/arithmetic_int8.h", "nnacl/int8/sub_int8.h"}, {"arithmetic_int8.c", "sub_int8.c"});
+  Collect(context, {"nnacl_c/int8/arithmetic_int8.h", "nnacl_c/int8/sub_int8.h"}, {"arithmetic_int8.c", "sub_int8.c"});
   NNaclInt8Serializer code;
   // Todo: Parallel run wrapper
   auto element_num = output_tensor_->ElementsNum();
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h
index a616c143932416c370a6a848e7f50e51df653f99..79aa9ae362f0ebcbab829c89514fd083ce52a5d2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"
 namespace mindspore::lite::micro::nnacl {
 class SubInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc
index 6704d9ee305b1634c25428aea51afe350fbefed7..23f8451628a476788e80b8cd78b6e575902b376f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc
@@ -21,7 +21,7 @@
 #include "include/errorcode.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
-#include "nnacl/int8/tanh_int8.h"
+#include "nnacl_c/int8/tanh_int8.h"
 namespace mindspore::lite::micro::nnacl {
 int TanhInt8Coder::Prepare(CoderContext *const context) { return RET_OK; }
@@ -29,7 +29,7 @@ int TanhInt8Coder::Prepare(CoderContext *const context) { return RET_OK; }
 int TanhInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/tanh_int8.h",
+            "nnacl_c/int8/tanh_int8.h",
           },
           {"tanh_int8.c", "activation_fp32.c"});
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc
index 43a19abc68ba0f0b01599cfc9ab8a7b132cf3f9c..9fa23bb7530c3af69a290b1f19b7d57d5e6eca39 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc
@@ -49,7 +49,7 @@ int TransposeInt8Coder::Prepare(CoderContext *const context) {
 }
 int TransposeInt8Coder::DoCode(CoderContext *const context) {
-  Collect(context, {"nnacl/int8/pack_int8.h", "nnacl/int8/transpose_int8.h"}, {"pack_int8.c", "transpose_int8.c"});
+  Collect(context, {"nnacl_c/int8/pack_int8.h", "nnacl_c/int8/transpose_int8.h"}, {"pack_int8.c", "transpose_int8.c"});
   NNaclInt8Serializer code;
   auto out_shape = output_tensors_[0]->shape();
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h
index 12e06f865021a9faea2e75df095ff06a7210254e..0c49a6f359fefe4769efd473b498cdbadf91dc0b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h
@@ -18,7 +18,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class TransposeInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
index 4d27c7008a7a6bd91a3f63a6eeb581466e4a4d3b..8201f56ed872d7d14b771c13684b1b51770c90f2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
@@ -20,43 +20,43 @@
 #include
 #include
 #include "coder/opcoders/serializers/serializer.h"
-#include "nnacl/batchnorm_parameter.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
-#include "nnacl/scale_parameter.h"
-#include "nnacl/slice_parameter.h"
-#include "nnacl/split_parameter.h"
-#include "nnacl/transpose_parameter.h"
-#include "nnacl/base/tile_base.h"
-#include "nnacl/fp32/transpose_fp32.h"
-#include "nnacl/pooling_parameter.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/splice_parameter.h"
-#include "nnacl/lstm_parameter.h"
-#include "nnacl/group_norm_parameter.h"
-#include "nnacl/activation_parameter.h"
+#include "nnacl_c/scale_parameter.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/split_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
+#include "nnacl_c/base/tile_base.h"
+#include "nnacl_c/fp32/transpose_fp32.h"
+#include "nnacl_c/pooling_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/splice_parameter.h"
+#include "nnacl_c/lstm_parameter.h"
+#include "nnacl_c/group_norm_parameter.h"
+#include "nnacl_c/activation_parameter.h"
 #include "wrapper/fp32/dequant_int8_to_fp32_wrapper.h"
-#include "nnacl/fp32/exp_fp32.h"
-#include "nnacl/fp32/strided_slice_fp32.h"
-#include "nnacl/tensor_c.h"
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/fp32/strided_slice_fp32.h"
+#include "nnacl_c/tensor_c.h"
 #include "wrapper/fp32/arithmetic_fp32_wrapper.h"
 #include "wrapper/base/affine_wrapper.h"
 #include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
-#include "nnacl/instance_norm_parameter.h"
-#include "nnacl/layer_norm_parameter.h"
-#include "nnacl/broadcast_to_parameter.h"
-#include "nnacl/custom_gru_parameter.h"
-#include "nnacl/unstack_parameter.h"
-#include "nnacl/kernel/scale.h"
-#include "nnacl/kernel/pooling.h"
-#include "nnacl/kernel/layer_norm.h"
-#include "nnacl/kernel/fill.h"
-#include "nnacl/kernel/batch_norm.h"
-#include "nnacl/kernel/tile.h"
-#include "nnacl/kernel/slice.h"
-#include "nnacl/kernel/strided_slice.h"
+#include "nnacl_c/instance_norm_parameter.h"
+#include "nnacl_c/layer_norm_parameter.h"
+#include "nnacl_c/broadcast_to_parameter.h"
+#include "nnacl_c/custom_gru_parameter.h"
+#include "nnacl_c/unstack_parameter.h"
+#include "nnacl_c/kernel/scale.h"
+#include "nnacl_c/kernel/pooling.h"
+#include "nnacl_c/kernel/layer_norm.h"
+#include "nnacl_c/kernel/fill.h"
+#include "nnacl_c/kernel/batch_norm.h"
+#include "nnacl_c/kernel/tile.h"
+#include "nnacl_c/kernel/slice.h"
+#include "nnacl_c/kernel/strided_slice.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h
index 245c11416f79b050392f196041c6b33c60562e21..dd016f6e7e3b055837e5ec3aab7b52d152e0eae9 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h
@@ -18,26 +18,26 @@
 #include
 #include
 #include "wrapper/base/affine_wrapper.h"
-#include "nnacl/pooling_parameter.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/pooling_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 #include "coder/opcoders/serializers/serializer.h"
-#include "nnacl/op_base.h"
-#include "nnacl/int8/add_int8.h"
-#include "nnacl/int8/arithmetic_int8.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/add_int8.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
-#include "nnacl/int8/concat_int8.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/reshape_parameter.h"
-#include "nnacl/slice_parameter.h"
-#include "nnacl/batchnorm_parameter.h"
-#include "nnacl/pad_parameter.h"
-#include "nnacl/transpose_parameter.h"
-#include "nnacl/int8/relux_int8.h"
+#include "nnacl_c/int8/concat_int8.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/reshape_parameter.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
+#include "nnacl_c/pad_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
+#include "nnacl_c/int8/relux_int8.h"
 #include "wrapper/int8/concat_int8_wrapper.h"
-#include "nnacl/kernel/pooling.h"
-#include "nnacl/kernel/batch_norm.h"
+#include "nnacl_c/kernel/pooling.h"
+#include "nnacl_c/kernel/batch_norm.h"
 namespace mindspore::lite::micro::nnacl {
 class NNaclInt8Serializer : public Serializer {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc
index b4e2e2edf8ac99ff673af45fd3f0cccd76022eca..c70505921de8904e9460e8cdecd15d55f38d9381 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc
@@ -16,11 +16,11 @@
 #include
 #include
-#include "nnacl/pooling_parameter.h"
-#include "nnacl/slice_parameter.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/int8/add_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/pooling_parameter.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/int8/add_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h
index 6a219becfbddf4f92d51650cde66cd6aea437ab8..ae06d2877361a101de0a076f29bb844a3af29774 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h
@@ -18,12 +18,12 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_SERIALIZERS_NNACL_SERIALIZER_NNACL_STREAM_UTILS_H_
 #include
 #include
-#include "nnacl/op_base.h"
-#include "nnacl/pooling_parameter.h"
-#include "nnacl/slice_parameter.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/int8/add_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pooling_parameter.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/int8/add_int8.h"
+#include "nnacl_c/int8/quantize.h"
 namespace mindspore::lite::micro {
 std::ostream &operator<<(std::ostream &code, const ::QuantArg &quant_arg);
diff --git a/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h b/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h
index d8c03ab69ae53d6b7d7124ccfe012e938e784ac7..4c6b1a0d30650811d07356135533724f0f7552e4 100644
--- a/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h
+++ b/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h
@@ -25,7 +25,7 @@
 #include "ir/dtype/type_id.h"
 #include "include/api/format.h"
 #include "src/common/log_adapter.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/converter/micro/coder/config.h"
 #include "base/float16.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h
index c0a3e6e6a660befb7527036a4b3e26c27caab94e..1cad479a6ac28c1f566ff607dae0ace9593484d5 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h
@@ -15,7 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_AFFINE_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_AFFINE_WRAPPER_H_
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h
index c1c6e9aa88cc44948fdb77b65ccfbfd8316c567b..73349ceea40e56139d36cfa7713041a3cf3e0ca8 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_COMMON_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_COMMON_WRAPPER_H_
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 bool GetSupportOptFlag();
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_COMMON_WRAPPER_H_
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h
index 3763bcdfb626b8bcbd7676a38f6de5cd47279626..e8bf7eecf4ad74e08d7ed83e608a281528e52cb4 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_MICRO_PARAMETER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_MICRO_PARAMETER_H_
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 typedef struct {
   ActType act_type_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h
index 8d3000f0299652f17bf1b8fb26994feebe111160..8835550aea4f9f13cdb96af91478e66f6c2f87f1 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_OPTIMIZE_HANDLER_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_OPTIMIZE_HANDLER_WRAPPER_H_
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #ifdef ENABLE_ARM64
 void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c
index 2a3f17158c7d7c2da055d75ffa705de7418b7e13..60a543270cbc3ee357a4c1df1e94c3f98e5f23d8 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/base/strided_slice_wrapper.h"
-#include "nnacl/fp32/strided_slice_fp32.h"
+#include "nnacl_c/fp32/strided_slice_fp32.h"
 int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) {
   StridedSliceStruct strided_slice;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h
index 8beb4051969e9677d0dbd28543c55c90ca6d1432..b9d7ede02cc4bca6970779921b089a0dcd8dd46b 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_BASE_STRIDED_SLICE_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_BASE_STRIDED_SLICE_WRAPPER_H_
 #include
-#include "nnacl/strided_slice_parameter.h"
+#include "nnacl_c/strided_slice_parameter.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c
index 07e0f1037f214d43e25ba76cfc408c56e7cf3433..4dec9396b2e568db7b54b03dc3645b23eca40e78 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/activation_fp32_wrapper.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoSigmoid(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ActivationFp32Args *args = (ActivationFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h
index 64a13b0e864f72a079667c3970e95d1a71ef1c2b..d4eb98c0c99a7833854ad767a000ec4d18a14ec0 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_ACTIVATION_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_ACTIVATION_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 typedef struct {
   const float *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h
index 7c4d3001ee58c48ff4f48d0263aaa346381203a8..3e58197e757c1c150e8cbc70e201d90d497eea1e 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h
@@ -15,7 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
 #include
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c
index 144a1e93efd239ab09533a1709e194773670616d..6052ffca157a40c0a1f33998790c5a1fb29deb2f 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/concat_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
-#include "nnacl/base/concat_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/base/concat_base.h"
 int DoConcatRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ConcatFp32Args *args = (ConcatFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c
index e1dabe5a5c6465032a3ec9b491c0ecec1da696da..7f2810c6b88565e270f5c610587ed6442c442c74 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c
@@ -16,8 +16,8 @@
 #include "wrapper/fp32/conv_fp32_wrapper.h"
 #include
-#include "nnacl/errorcode.h"
-#include "nnacl/fp32/conv_common_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
 int ConvFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ConvFp32Args *args = (ConvFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h
index 4ee4ecb962afdbe361a6a41dab1e668e4c185772..96aab4e739667c13a2ad1b5ba27e01fc62e2a2f9 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h
@@ -15,7 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_FP32_WRAPPER_H_
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c
index 69a431e5b6d8872be773670030c3c2217c1108f0..9153ee3d185603650d3e2dcff1a3dbc3685f42b9 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/errorcode.h"
 int ConvWinogradFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ConvWinogradFp32Args *args = (ConvWinogradFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h
index fe91a2881761421db4adce68517da0d25d9acb57..f71bc64521fffd2b1822c7f2090d753324d80497 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h
@@ -15,8 +15,8 @@
  */
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
-#include "nnacl/fp32/winograd_utils.h"
-#include "nnacl/fp32/conv_winograd_fp32.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/conv_winograd_fp32.h"
 #ifdef __cplusplus
 #include
 typedef struct TransFuncStr {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c
index 917effc89f00a8e4d6e78cb6e9cea2439ce56b88..4a438edbc5dcf8363f7e8dcb8d4d2ae29f517eef 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/deconvolution_fp32_wrapper.h"
-#include "nnacl/fp32/deconv_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/deconv_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 int DoDeconvFp32(const float *packed_input, const float *packed_weight, const float *packed_bias, float *packed_output,
                  float *output, float *tmp_ori_buffer, const MicroMatmulParameter *matmul_param,
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h
index 64d18ef2418ecdbf36c28de50d281ac4854d104c..78bade2affec9aa4ed93b730cf82d76c37eef861 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_DECONVOLUTION_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_DECONVOLUTION_FP32_WRAPPER_H_
-#include "nnacl/errorcode.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
 typedef struct {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c
index e102fe0c4bfb1d258bfaf820597e30d18e5d527e..bb7827bd491fd0459dfc2884a495325fc9148711 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/fill_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
-#include "nnacl/base/fill_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/base/fill_base.h"
 int DoFillFp32(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   FillFp32Args *args = (FillFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c
index a0a2dd95f2f17554662b9116fc4a0eef9c195c60..75551a98f9039141665743ade39ef5c5ded1d24b 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/matmul_fp32_wrapper.h"
-#include "nnacl/fp32/pack_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
 void InitMatrixA(const float *src_ptr, float *dst_ptr, const MicroMatmulParameter *params_, bool is_vector_a) {
   if (is_vector_a) {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h
index 4303a6ce37fa6799039a970873a5847a73d6415b..c3b34d57cf830c69677048a61b08566d484e163b 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_MATMUL_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_MATMUL_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 #include "wrapper/base/micro_parameter.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c
index e6253af1bcd70ff65d043da6a4959bcd555c075f..61705faf528744572654cf8f3b20e71777d7ddaf 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/pooling_fp32_wrapper.h"
-#include "nnacl/fp32/pooling_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoMaxPooling(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   PoolingFp32Args *args = (PoolingFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h
index 34062d3cfb5b52151fb9cdea91d8bc3349eb394d..72326f74b9a423e36f5601ae9602771f7b6f3862 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_POOLING_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_POOLING_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/pooling_fp32.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
 typedef struct {
   const float *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c
index 69ff70fc22b8e353c3d8a6f6fbfd61ba4d878485..2ffd06eea1dfaa6040a87057a0f291ade40f6c29 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/scale_fp32_wrapper.h"
-#include "nnacl/fp32/scale_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/scale_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoScaleReluRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ScaleFp32Args *args = (ScaleFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h
index 7fd713876b3d5052465a8ff77e0605c206e82cdf..5bc688862114e65b836c7ec0a68f36c3c707aa36 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SCALE_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SCALE_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/scale_fp32.h"
+#include "nnacl_c/fp32/scale_fp32.h"
 typedef struct {
   const float *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c
index 6227fbd07645c94189c42adc1f832d24bc516b6a..68bb194e92a6d17ccb755dce1abfcb2c66d363c5 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/slice_fp32_wrapper.h"
-#include "nnacl/base/slice_base.h"
+#include "nnacl_c/base/slice_base.h"
 int DoSliceRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   SliceFp32Args *args = (SliceFp32Args *)(cdata);
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h
index 33aa514dc464206d4267c466f82df2e8bbd4c88b..f1a829c7452f0edacc9c2c8f23e79f424a5f43b8 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h
@@ -18,8 +18,8 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_SLICE_FP32_WRAPPER_H_
 #include
-#include "nnacl/slice_parameter.h"
-#include "nnacl/kernel/slice.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/kernel/slice.h"
 typedef struct {
   float *input_data_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c
index 8e42305ab7b4769b637f714d0b59a97a866c6654..dae81be527a1f1c3079c5099b8b079c00400ee19 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/split_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/errorcode.h"
 int DoSplitRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   SplitFp32Args *args = (SplitFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h
index 9309be5056db4eb2ea956c9ea4f13d81fa47ba11..e6bc6557a0414fc6d4ca955525a5fc69583cba91 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SPLIT_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SPLIT_FP32_WRAPPER_H_
 #include
-#include "nnacl/base/split_base.h"
+#include "nnacl_c/base/split_base.h"
 typedef struct {
   const void *input_ptr_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c
index b75da08af023b3ed9dd809e8a606d5ec22fd0cab..88de73b301f624218313afe826874c4efedaf907 100644
---
a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c @@ -16,8 +16,8 @@ #include "wrapper/fp32/transpose_fp32_wrapper.h" #include -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/errorcode.h" int DoTransposeNCHWToNHWC(void *cdata, int task_id, float lhs_scale, float rhs_scale) { TransposeFp32Args *args = (TransposeFp32Args *)cdata; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h index 61462ff154aca65b014e16a15c096a8e685c01de..b63a9b123101dcf72aa38ad6e2a704914737ff1b 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h @@ -17,8 +17,8 @@ #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_TRANSPOSE_FP32_WRAPPER_H_ #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_TRANSPOSE_FP32_WRAPPER_H_ #include -#include "nnacl/fp32/transpose_fp32.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/transpose_parameter.h" typedef struct { const void *input_; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c index df8287b6acc1bbbf2ad4ee48e8327596586f42df..aac9a5818680bb6423a1e1ddd2cd21c1e5d8ce89 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c @@ -15,7 +15,7 @@ */ #include "wrapper/int8/add_int8_wrapper.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" int AddBroadcastInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) { AddInt8Args *args = (AddInt8Args *)(cdata); diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h index ad4e09c6c2fc44bd72e8e0a873633b0a721665df..8411709ae6869026a88ccb1a0c1d7c1a144c013f 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_ADD_INT8_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_ADD_INT8_WRAPPER_H_ #include -#include "nnacl/int8/matmul_int8.h" -#include "nnacl/int8/add_int8.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/arithmetic_parameter.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c index 33777ee41d11ddf17c1e556bb0e9ef6405014fe9..7d8494d57cb6d60e7fbdd4eaddd95beaeca55ed2 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c @@ -15,8 +15,8 @@ */ #include "wrapper/int8/batchnorm_int8_wrapper.h" -#include "nnacl/int8/batchnorm_int8.h" -#include "nnacl/errorcode.h" +#include 
"nnacl_c/int8/batchnorm_int8.h" +#include "nnacl_c/errorcode.h" int BatchNormInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) { BatchNormArgs *args = (BatchNormArgs *)(cdata); diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h index cac6f1ec01f88e29eac060775c6b8adc948c4a2d..28c349cd2117cdd586b7015f1cac7020ddcda8b2 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_BATCHNORM_INT8_WRAPPER_H_ #include -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" typedef struct BatchNormArgs { int8_t *in_addr_; int8_t *out_addr_; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h index cc23a3896bfbc8672c0e2ca4a1c10bc1052a02af..a019056ac7911bf2ca6ab34f997cc9c85efaf40a 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONCAT_INT8_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONCAT_INT8_WRAPPER_H_ -#include "nnacl/errorcode.h" -#include "nnacl/concat_parameter.h" -#include "nnacl/int8/concat_int8.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/int8/concat_int8.h" typedef struct { int8_t **inputs_; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c index e7c604e36a67694435ff800e6490056be42bc482..f9ffb0a79d1e7ac264e7542b9d3df2139e170d0e 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c @@ -15,8 +15,8 @@ */ #include "wrapper/int8/conv1x1_init_int8_wrapper.h" -#include "nnacl/int8/matmul_int8.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/errorcode.h" size_t Conv1x1PackWeightSize(int32_t input_channel, int32_t output_channel, bool support_optimize) { size_t size = 0; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h index 1a721d57e2be160b181834712e48e18373086043..17b411ec45da0693fae747b4ec5754cc7de8e9a1 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c index f9df30f2a45e8216f121e29160e1b74db566441c..b8a676475bd2c5084b464716d4f2338eb48b7a0e 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c +++ 
b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c @@ -15,11 +15,11 @@ */ #include "wrapper/int8/conv1x1_run_int8_wrapper.h" -#include "nnacl/base/conv1x1_base.h" -#include "nnacl/int8/matmul_int8.h" -#include "nnacl/int8/pack_int8.h" -#include "nnacl/int8/conv1x1_int8.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/base/conv1x1_base.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/pack_int8.h" +#include "nnacl_c/int8/conv1x1_int8.h" +#include "nnacl_c/errorcode.h" void Pre1x1Trans(Conv1x1Args *args, int8_t *src_input, int8_t *src_output) { args->output_ptr_ = src_output; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h index 6a51d9606375cc632229ac9fef3d0f8f68feaed2..f1c940706047e86f78ba44094f5bfd465ca85821 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h @@ -19,8 +19,8 @@ #include #include -#include "nnacl/conv_parameter.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "wrapper/base/micro_parameter.h" typedef struct { diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h index d2dad072a3ca075d001a58044a20641b548dbb77..1663d987837ff46993f87d8983f8a32ff06fc91f 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONV3X3_RUN_INT8_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONV3X3_RUN_INT8_WRAPPER_H_ -#include "nnacl/errorcode.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/int8/conv3x3_int8.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/int8/conv3x3_int8.h" typedef struct { int16_t *input_data; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c index d7454de57493d52847504ffb662c0c78884bf0c8..2cf29d3cbaf539d8471d399566db7fe75ad4e08a 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c @@ -15,9 +15,9 @@ */ #include "wrapper/int8/conv_init_int8_wrapper.h" -#include "nnacl/op_base.h" -#include "nnacl/int8/matmul_int8.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/errorcode.h" size_t ConvPackWeightSize(int input_channel, int output_channel, int kernel_plane, bool support_optimize) { size_t up_round_deep; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h index cd5c9e8921c10dee2ad8ed082ac5c41871eefe97..21493fe8e5c97fa1645622a3afa302e999c4a2e0 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h +++ 
b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_DEPTHWISE_INT8_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_DEPTHWISE_INT8_WRAPPER_H_ -#include "nnacl/errorcode.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/int8/conv_depthwise_int8.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/int8/conv_depthwise_int8.h" typedef struct { int8_t *output_data_; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h index 00a4ce9f2131341809e9ca94f1003bd514a1d24e..166cbbef55a7f77ed2d824a5e3e1bb5aab970dcb 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h @@ -17,10 +17,10 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_INT8_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_INT8_WRAPPER_H_ -#include "nnacl/errorcode.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/int8/conv_int8.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/int8/conv_int8.h" typedef struct { int8_t *input_data_; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h index ea2f340f7352093abe945ecc5f78d155932d471e..87666ed56b32cd1475c638054d8fc3b49b32c286 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h @@ -17,7 +17,7 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_MATMUL_INT8_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_MATMUL_INT8_WRAPPER_H_ #include -#include "nnacl/int8/matmul_int8.h" +#include "nnacl_c/int8/matmul_int8.h" #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c index 81dc157856d90914d6a08f3f4384456b6c66ae96..f82f18afa6b055fc7f782b6e7dc58a0de6d37b77 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c @@ -15,7 +15,7 @@ */ #include "wrapper/int8/resize_int8_wrapper.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" int ResizeInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) { ResizeInt8Args *args = (ResizeInt8Args *)cdata; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h index f66d244c86ed462675aee4ce9efa6c15b5b5b237..5cbb1620b072319d9275144acbb78d662338e48e 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h @@ -17,7 +17,7 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_RESIZE_INT8_WRAPPER_H_ #define 
MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_RESIZE_INT8_WRAPPER_H_ -#include "nnacl/int8/resize_int8.h" +#include "nnacl_c/int8/resize_int8.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c index 5581bdc074f5fffcfd5d39de01c6a2d4bc010ae0..348e39d8a1aab4074968cee571b798b6ed846075 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c @@ -15,7 +15,7 @@ */ #include "wrapper/int8/slice_int8_wrapper.h" -#include "nnacl/int8/slice_int8.h" +#include "nnacl_c/int8/slice_int8.h" int SliceInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) { SliceArgs *args = (SliceArgs *)(cdata); diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h index e593ef7b6128f1a7c99df13db48befd6a30488f9..b4e703d7399de0c7dad0c8fc3ab1d3006759e9bb 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h @@ -18,8 +18,8 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_SLICE_INT8_WRAPPER_H_ #include -#include "nnacl/slice_parameter.h" -#include "nnacl/kernel/slice.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/kernel/slice.h" typedef struct SliceArgs { int8_t *input_data_; diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c b/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c index 5a452add1e24c81e0a8056dae6c6804dfd77cad6..c388fec06ca248edff306f0f842d6336383cbe04 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c @@ -25,7 +25,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" int GetCpuCoreNum() { int core_num = 1; diff --git a/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h b/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h index b4698a2f705b6e24e095855b0bc8b34031520e10..1ae97a2d519a2dc44152f634676da707d7b83ae7 100644 --- a/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h +++ b/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h @@ -17,8 +17,8 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_PROVIDERS_NNIE_NNIE_MICRO_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_PROVIDERS_NNIE_NNIE_MICRO_H_ -#include "nnacl/custom_parameter.h" -#include "nnacl/tensor_c.h" +#include "nnacl_c/custom_parameter.h" +#include "nnacl_c/tensor_c.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore-lite/tools/converter/offline_packing_optimizer.cc b/mindspore-lite/tools/converter/offline_packing_optimizer.cc index f7d610cb02d78c6ab8c4bdf0859f2146d78742e5..19b0ca13e57aff52c99cb4c973bbaaf91fbacff9 100644 --- a/mindspore-lite/tools/converter/offline_packing_optimizer.cc +++ b/mindspore-lite/tools/converter/offline_packing_optimizer.cc @@ -27,7 +27,7 @@ #include "src/common/primitive_t_utils.h" #include "src/common/ops/anf_utils.h" #include "src/common/file_utils.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h" 
#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/converter/ops/while.cc b/mindspore-lite/tools/converter/ops/while.cc index 49f662230d3a3c95d183998695bcaba6e179048b..15ef4c018ccf04414951d5fe0b88eb03439a3415 100644 --- a/mindspore-lite/tools/converter/ops/while.cc +++ b/mindspore-lite/tools/converter/ops/while.cc @@ -19,7 +19,7 @@ #include "tools/converter/ops/while.h" #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/optimizer_manager.cc b/mindspore-lite/tools/converter/optimizer_manager.cc index 44d66ddab8420641f5e88dd8bd5c665d0fe20c71..e567ac04e43a7c17222447ddf81017a48dbcb279 100644 --- a/mindspore-lite/tools/converter/optimizer_manager.cc +++ b/mindspore-lite/tools/converter/optimizer_manager.cc @@ -24,7 +24,7 @@ #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "include/registry/pass_base.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc index 8681801fdfef9015a29ce3956efadfbfbc8d30f3..69309b2f0c554e3620c77dff391733d68215780d 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_activation_parser.h" #include #include "infer/cxx_api/activation.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc index a7e34625d6281a98b7ff4beb61bd8f05173cfabb..38a52abed3a569348580a97609a00e9794e54e08 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_argmax_parser.h" #include #include "infer/cxx_api/arg_max_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc index 8dd36c92fbcc05df4c267ed028f2729113b6f9e9..102f78e5907122982a1f4343939c6cefc11ed2bf 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc @@ -21,7 +21,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc index 59e417836381f8496d90a48e33dc8a43ec33d910..fc5cf48f55a5d0d47fb3be5d20a012c5d9423cf2 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_concat_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc index 6d7495931b6efad03b731abe75c44d8f59beb673..766d41b4a43421332f4d02e887e099817dcfa73e 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc @@ -16,7 +16,7 @@ #include "tools/converter/parser/caffe/caffe_conv_base_parser.h" #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc index 379ea42c9f9e0bf345c151ea77f1640c87257e79..ac8591e5348ccf0e06b7bef5ac9e1584736f93f3 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_convolution_parser.h" #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc index 3de900b0e2b090123f5dbbd89e64e07fdb6c831c..f41c923692871dc40191216415477104469a4521 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_crop_parser.h" #include #include "infer/crop.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc index ab415486d0cc65ca582fcae8d90d95db31b867e6..56c2566ac46e65252795dba0093eac215d3acf5a 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc index 2fe0e8389f8880b527a725de956f6705b6b99045..e66c0f8b2a301f26cd8601c39868200845a4f684 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/eltwise.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc index 
f9adf5a1785cb967b1a60a80d0e9ba28f0e36838..8a9c5c0967522dfe0ee7526b3917a5e2f6871c07 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/exp_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc index 3c711bd5f8d5edfbcd1ffa4399436b5ff21d09ea..927a43e028634bc03108398393261cffeda1a193 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_flatten_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc index 882f3542f0b0f04b8ca3d34b58c32fa52bf4f220..c0f7440904638b4b8f0b204faf94537701107b61 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/cxx_api/full_connection.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc index 8cc1d71c96f34797dc9aab04e280993beaef2fd3..2e944ba43f613ff8cb95948523fc1b96d4015fee 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_interp_parser.h" #include #include "infer/resize.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc index ce29d1d9e1f3c43f16b244ade2b1ea6efce120e0..f2cc41e7db4da40d4a67800d1ef5699ed0930c9e 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -32,7 +32,7 @@ #include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/parser/unify_format.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc index f11c024fb902dc29d1031f72822cafc0f57c89c2..ec47182b0b3a9ce34d47ca303f8c612a0242edc7 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc @@ -19,7 +19,7 @@ #include "include/securec.h" #include "ir/dtype/type_id.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" 

 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc
index 0ab3ce1e10127d6e71a0182b685ae5b368a108e5..a997b25d139b1ff19e1c3e8fa1303271994aff3d 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_permute_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
index d4ff92714076d6d304e472ebe614e8aa3f1e3191..80973264c2a904b31ac843f8ffb2b342eb2075e1 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
@@ -20,7 +20,7 @@
 #include "infer/cxx_api/max_pool_fusion.h"
 #include "ops_utils/op_utils.h"
 #include "include/registry/converter_context.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_util.h"

 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc
index 386716b6479fb3f55ac31559a75c5075dd8ca55a..a0ef44ed58287aeb565231e0a2199483a94f3daa 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/pow_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
index e0ce4d03a5d4c9aafd8c205ce61b4113cddc75c6..109bc83fcc04d75ca13a2b72bb37eecc05d249d1 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_prelu_parser.h"
 #include
 #include "infer/cxx_api/prelu_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc
index 8ccceb31aedd7db584068f7e549f442c36b18028..1e5ca99b123df7b8b13820d8f6e4ea03545f0b5f 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc
@@ -18,7 +18,7 @@
 #include "tools/converter/parser/caffe/caffe_quantize_parser.h"
 #include
 #include "infer/quant_dtype_cast.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc
index 4c4f88e1a672d3c820067bca8f046078e7ef52f8..7923ae842ca7f8e8cd415c25a5cb008d1a4aacea 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/reduce_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc
index fa31d9fd0bb5b7e92470cd86dd0a3c69d4c0f5c2..ecc748693e0ca556070a259ad35e01d58708d1bc 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_reshape_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc
index ac96c8c3677dc0c96914e946317fc5fe79d15a1f..fbe30ac4caca70e6157216898b9eb10d5f9eabed 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_scale_parser.h"
 #include
 #include "infer/cxx_api/scale_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_util.h"

 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc
index 47ef93d37dee9ea84198701b24a19f24469020ad..01c461cbd06c1f52bedbd90c6f7306ce2b8f9659 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_slice_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc
index 6f19e643103f4c618477709ea0c448f9d36c6e53..f1ac88339a8c20ec7240fdc6a535bd25d98ccff9 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_softmax_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc
index ad3c89931d5fd597f85b084c8df0d07713586faa..79b0fb98741bf6691e8e966534b0c18b8b09ded1 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/tile_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc
index 2bc24c15e3e5496d0485cd567794a6b33d081df1..ca174823f297455eaa84bb262e7ff1e0c8471a82 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include "infer/resize.h"
 #include "ops_utils/op_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc b/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc
index e3db1d30a025f9d4fec8aeecf7da84a8f04b9a66..ace278a0502c8c0bf87a06db24e38c27a691b535 100644
--- a/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc
+++ b/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc
@@ -29,7 +29,7 @@
 #include "infer/unsqueeze.h"
 #include "ops/primitive_c.h"
 #include "tools/optimizer/common/gllo_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "ops_utils/op_utils.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
diff --git a/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc b/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc
index 2ca880c69d8c2b67cb844caf9d7e6614158c7094..a027d2d5bd081da0f40067548bacd01daf237781 100644
--- a/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc
+++ b/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/conv2d_transpose_input_adjust.h"
 #include "mindspore/ops/op_def/lite_ops.h"
 #include "tools/converter/parser/parser_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "infer/cxx_api/conv2d_transpose_fusion.h"
 #include "mindspore/ops/op_def/op_name.h"
diff --git a/mindspore-lite/tools/converter/parser/einsum_adjust.cc b/mindspore-lite/tools/converter/parser/einsum_adjust.cc
index d7ea8c309344cbbb17faf4eab1dac2c56ed7070b..6e2bf9477862b50f0ddc438840f6ce978ec255a2 100644
--- a/mindspore-lite/tools/converter/parser/einsum_adjust.cc
+++ b/mindspore-lite/tools/converter/parser/einsum_adjust.cc
@@ -24,7 +24,7 @@
 #include "tools/converter/ops/ops_def.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "src/common/utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/unsqueeze.h"

 namespace mindspore::lite {
diff --git a/mindspore-lite/tools/converter/parser/inputs_adjust.cc b/mindspore-lite/tools/converter/parser/inputs_adjust.cc
index 056a65411a7fb7d308c7b2e2989ba363cfe6aa12..2181204f91b020fa6df4f51f6a754913aad93426 100644
--- a/mindspore-lite/tools/converter/parser/inputs_adjust.cc
+++ b/mindspore-lite/tools/converter/parser/inputs_adjust.cc
@@ -19,7 +19,7 @@
 #include "mindspore/ops/op_def/lite_ops.h"
 #include "mindspore/ops/op_def/array_ops.h"
 #include "ops/primitive_c.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"
diff --git a/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc b/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc
index 896b1d375ab0b67a7e7c6a0b8b5d4f146ad568de..91dc6b487165c8d8b5ca9156cd53e88eddcb2b4f 100644
--- a/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc
+++ b/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc
@@ -25,7 +25,7 @@
 #include "tools/lite_exporter/fetch_content.h"
 #include "tools/common/tensor_util.h"
 #include "utils/check_convert_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc
index 67061df9a0d055acc3687dc3b52901f348c2aa65..8b4a15df92e81bc0d2c59499fa02725e475d5cde 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc
@@ -21,7 +21,7 @@
 #include "infer/cxx_api/prelu_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/activation.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/softplus.h"
 #include "infer/selu.h"
 #include "infer/ops_func_impl/hswish.h"
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc
index e58aff5351e8e962e657cc84f2556eb69bf0216e..f20c9e9e08a75e586c55bc220260f2f37d846f2e 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_adder_parser.h"
 #include
 #include "infer/cxx_api/adder_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc
index 73b99356685dc7ac60bf8b4bd964056b2f75364d..8e27b2db723bef99dbab5744c152291eae473832 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_argmax_parser.h"
 #include
 #include "infer/cxx_api/arg_max_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc
index 6aa23f157386dda4274e1ace718fffbac7a3c862..e53d880c7a74ba69b6cf377a33203b99433756ff 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_argmin_parser.h"
 #include
 #include "infer/cxx_api/arg_min_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
index b2ecc78af3d9851aa10b44ad0beecce2283d3be9..e04c7666d989f00fc631704f9331d0a5b9a459dc 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
@@ -27,7 +27,7 @@
 #include "infer/cxx_api/pow_fusion.h"
 #include "infer/eltwise.h"
 #include "infer/mod.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"

 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc
index 24f94ffd5c4c0bd431d7f9671a2e326ca4346ed1..df91c890c9b20597102149782f323f28a0a014b8 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_batchnorm_parser.h"
 #include
 #include "infer/fused_batch_norm.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc
index 29cb2b10c4053878646b93a2dd8937ba741fe4bc..4033c8f694269cbe5a15f8cd6d044bf20005de9a 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_biasadd_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc
index cc8382d46d0278d2340ba6ea7dd09434145bf9d2..9e6927fe040f1123f3329fa629644e259359a41e 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc
@@ -18,7 +18,7 @@
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc
index 7922ef6ef4a0dd471decee846dc462a1d072485f..1b64bc56182368452bf5a19694e913ef0420ffa0 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_clip_parser.h"
 #include
 #include "infer/clip.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc
index 90ec2fdd7f02515517d202a252f7b0f2a736d337..6e64a2afd591a5a3382b8b00b1d2ff4f54217f30 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/col2im.h"

 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc
index ca44f90d12b3e13d03064201b6aec9812a1b79e6..d85a26707a3617af8dabe068204cf9d8b3897c66 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_concat_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc index f797a75b11d129a7e04903b1abb93d6040c0f25b..f9492918101a4af98c9134b20af1c519a4147c8e 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc @@ -19,7 +19,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "infer/constant_of_shape.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc index 9c3b0608d14b893d1cbdcde5d747b2b3e81b9357..cb8974cd978ac3420f0fd40ac199e04e802c4935 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -24,7 +24,7 @@ #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/ops/ops_def.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc index 282a6dc844a49b716741e9510cf9d28697784088..88dac606ee6c4f1f747d97baa0f0281872e5a39b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/custom.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc index c95dd15f3b221813a365a3007cb273c78c2c6c41..03d3452269b944909e57d758de913185769d1a01 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 9db3f0ac9bc3b62dc01e7c4a16c18c7634c8ad3d..f8939a3d5fbb511482d9594f4c96304efcf47dd5 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/infer/conv3d.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc index 97e6fdc1a9ef7db410cac4f096d28cba85e8568b..426949ba9e27846f5d1788bda4ea2c8fad307ddb 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc @@ 
-19,7 +19,7 @@ #include #include #include "infer/cxx_api/conv2d_transpose_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc index 406bb813011f7924f8be3e2ba9ae469216631462..f7b58108ee1f50c35312299418fbf8402943be96 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc @@ -28,7 +28,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/uniform_real.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/node_util.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc index 5582e396b61fc0821d5eca077b2026b035d66abd..9f72fb567335e5437bedbe450cce98ee11ac0dd8 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc @@ -20,7 +20,7 @@ #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/converter/ops/ops_def.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/affine_grid.h" #include "infer/histogram.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc index 8df0be873c5e2c780a4872f58f567a9f3e7c0a39..902f6c0b0bee25beb6ce6ae5fe6dc7861aeb88f0 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/deformable_conv2d.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc index c0385750723af5228fd96c8461ca4d991cf98920..f4b23ad2cac2f3295721a23072985a00cedefcb2 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h" #include #include "infer/depth_to_space.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc index 4d99d25815cb9f3cb63f35b061b29dd5fe97e684..d59c5c2ba5afb8edfbc1ba4ac0c16ebd02ac9059 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc @@ -16,7 +16,7 @@ #include "tools/converter/parser/onnx/onnx_dequantize_linear_parser.h" #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/ops/ops_def.h" #include 
"ops_utils/op_utils.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc index c7fa69d17c69ffcc2a558cbff9517694252a1b13..f352c2116716a7a421264537e3e05e8a5a065f51 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/ops_func_impl/dropout.h" #include "op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc index a31948e612f44aa83633adfd7f680d6eb17384de..03479a51799533c7fe287b59a78bb4c256a5e3c2 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_einsum_parser.h" #include #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc index 75f6a5135b13fd3978373fd6e10ef6e58dabe6ca..496d70343bf2faa9a3bbf7d8a8fec61b42c877ee 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc @@ -16,7 +16,7 @@ #include "tools/converter/parser/onnx/onnx_erf_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc index a4ee9ef3b5b5308d7990f4c63aa12a29e3ab7c01..1c07eb301896339e6977f1da518d42638ad820ba 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc index bc1a86a8cddbb1eea921b71473b1abb6b05b5813..afd6f01ca3194f9fa1329b157ed175405f5b5d4c 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/custom.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc index 67e156b7c16aa6eb43475c9ce79a4c846d9ce035..b43b6f2661383df58f6bd911b263df169bc31168 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gather_element_parser.h" #include #include 
"mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc index 4f8f4a4c01a152350222e0643fd2112ac0354d5f..12e52f65567a902a407eab375f1c34faadb23f9b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gather_nd_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc index bbd7c6b636c8e05eef9a02b3d0943aa04c4e1865..9bc2ef1aa290761bc5a21918dbe7441e99e37e15 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/ops_func_impl/gather.h" #include "op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc index d004683bf2042c09caa24bfe11c54da2e008fa10..eee72d530ebfab3450cfa65d4d3c1fc702772da7 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc @@ -21,7 +21,7 @@ #include #include "tools/common/tensor_util.h" #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc index 5b545a9c037fcd0e212a59313405d53cb3189750..3ebbb2c44fec16c1bc0c5a01d0a8512b1a19edaa 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gridsample3d_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_enum.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc index 7cec2ead4fc3fdcddc648a23f28c3bab5924d46f..0fd0fab7df1aa7874689a7b9bff53b2be144d36b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gridsample_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_enum.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc index 
4f85a0583628883cf01f8b9808a8fbc438c8936e..2a83e42f42a2fbf512e51ab8d59b7e5214e6c666 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/gru.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/infer/grad/gru_v2_grad.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc index 8d0258903232e8135aa6d49440b9de12ca9907f1..17254fb4b9efa90bc850ce4474dbecf358481511 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/ops_func_impl/hswish.h" #include "op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc index 8fc860954462971d377df1e8b8807c297e4db452..a4c345ec32d87fce4d2a783e146ab9e54f2bcc02 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc index 60066820e62ab9a5ec70bc4dcacfc39fd451b2de..9046619384d68b831c874d1a512c36efa88b88e5 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc index dee5ed4cd5cab2b50b4a754bfb6a085800acc552..ab244232c2f868c533ac2283084eebdb25a71924 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc @@ -30,7 +30,7 @@ #include "infer/multinomial.h" #include "infer/affine_grid.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/node_util.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc index ac7fe6b94e5eff9b7520346dc5c443fbcf4c2e85..e86e5b0bef6aaf08e763b814bb71e7f3fb5785df 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_instance_norm_parser.h" #include #include "infer/instance_norm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc index ecb1354ed3de5d677c08c58c508170ec24c4ab93..270c5bad0a6664d2f27681480efed1c01c72b257 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_layer_norm_parser.h" #include #include "infer/cxx_api/layer_norm_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc index 5c9643f4b41195e20f494698b6dad4c93db2b817..aab992911a5b6a1e01bd2cc95236129f82033ff7 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_less_or_equal_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc index 7d77ae5488a3af24104f9f18a316b9ace358497d..5a3da48d83562e7c0b9c2e0ddd4a3e018bb22441 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_log_softmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc index dec3f2dd3a9e30ff290078a1f548747ac18c1f21..642d12e4cba2c54a0975b2566c7d3fc03072becc 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/converter/ops/while.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc index e9502e2656a01efe373f2e6cd2bf2e0c3a11fd70..c3182f7f49d4c2dddcab88bf570d26a55881853f 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_lp_norm_parser.h" #include #include "infer/lp_normalization.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc index 1e2fce3dc75730a2e55a52ff6695bb6e7e2da21d..557ce9f062ad2a56becf4c221e48ffd47cc9184d 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -17,7 +17,7 @@ #include 
"tools/converter/parser/onnx/onnx_lrn_parser.h" #include #include "infer/lrn.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc index 0d5f29765cd286208037e35a9d12e4fd0cb012a0..de163a068261a70e1420efc37311ffbaa6cbada4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_lstm_parser.h" #include #include "infer/lstm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc index b9017ca3fe95f224cc11fcb5287160cc6a829bd9..771a1d14d20da766439b5e8c725833371e6e2734 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_matmul_parser.h" #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc index 3366376a29b3c70ee082f97a0f155a06a8f18752..b9ff3795cbca2047960c13cc25b878070fc69db2 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -25,7 +25,7 @@ #include "include/registry/node_parser_registry.h" #include "ir/func_graph.h" #include "mindspore/ops/op_def/nn_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc index 7d914229a9bf92a6d37fb6c5856b7a673a804aeb..746590611380a32f715a5f541ce0bee6788233e6 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -20,7 +20,7 @@ #include #include #include "tools/converter/parser/onnx/onnx_model_parser.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/file_utils.h" #include "utils/ms_utils_secure.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc index ec29bfe57933a3976bc6358e74aa76d64bc2f172..d94ddd4c3ebf5faebb40a3a99595f0f2db93c86a 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_non_max_suppression_parser.h" #include #include "infer/non_max_suppression.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc index 
db618ee1fc8eea1186b2bb274fb502307d478dbe..a3e2d237675ae7f98c08930bde5697f8b8c0471c 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "infer/where.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc index 6c58678129d421b14bf79b545de68c60b9dc0a05..4704e189097e8033d4ba3a859373ec81c8e967ba 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_onehot_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc index 3462848894b3ccc9285e9f22023d4cb55fb7b8ff..c68219f426ee908a7a4732c4480e8ed010e6fea4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc @@ -22,7 +22,7 @@ #include "ops/primitive_c.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc index fadbe8c462d9ba72a3b808e2862577bb2518a6ca..b4d8e5cca5df8e42c9ed1d599f2b8b504c0bf4a4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/pad_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc index c69f228477748d4e96da1a79c40ffa9704524661..314b50e0f123cfa4f25639bbd47ea1707f244b44 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/ops_utils/op_constants.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc index 6620572b2a93c6e6a8d1ef73ce25f2e795d8d75b..36a1dbe1c6d8ea6c7ec60b265f3c4516f35af0cf 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/custom.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_enum.h" 
#include "op_def/auto_generate/gen_lite_ops.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc index 95da0e3fb55805b5183f3ce27bc7c4e7501a58cf..7bd3a8db890a8aa2dd298e78321d6c232e8dbb13 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc @@ -24,7 +24,7 @@ #include "infer/cxx_api/mat_mul_fusion.h" #include "tools/converter/ops/ops_def.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/quantizer/quantize_util.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc index 2ff60a78aa59e922a6f7b04c6fcb78cb2389f817..d8c818ad07e6276978b074fccb9ba56480fed055 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "op_def/auto_generate/gen_lite_ops.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc index e548df48db6c229aade2ff47c78134b389f3194e..da96c8fd016bbf74553c6efd9dfc3cd9098b61a6 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_quantize_parser.h" #include #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc index d19164fcb448c3a98eba0128b957b5d8a80fc7bf..7b8fe37c81e2990ed55ae79307217060b074d7bb 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/random_normal.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc index 5d5be17152145b6f09eac66580517be9afaf056a..400f9f3a97a01bc9697edd922b56dce151a6e840 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_range_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc index f78ac25efddaba912f7cea25285a0de1a80f7486..98ad0af92b2988fd6049f83e133bbff1f370c6c7 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/reduce_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc index 77addebf3fd74918355b5f8dacf4e48d62d78850..e6df36a282e0675aebfb104d067ca38e9be68e84 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc index 25761368e9a17f55c5c5c187723c103a2bb13f93..0fd70b434b8631e9e98742727c9b30724e73219a 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc @@ -21,7 +21,7 @@ #include #include "infer/resize.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc index d3f2b006cb1c49dde9bf6d120e43b6c7926301bf..ee08a635dbeb3ca91529f350384e47be809d806f 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_reverse_sequence_parser.h" #include #include "infer/reverse_sequence.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc index a8b676603b21d8a8e0b18094c66df6ff0f409025..da6d1e808827de08dda0e2062b5ac19ce477d497 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_scatter_elements_parser.h" #include #include "infer/scatter_elements.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc index 99d97e8341e549a4a29b35f61fb1a03e0c28ea3f..569f67eb6c34d7beaf5da700afb5cb5551ad30c8 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_scatter_nd_parser.h" #include #include "infer/scatter_nd_update.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc index 
334f320f8f29a8e279455538993e083310ac8b8c..6f914950ca904dcf3503c601b927ab86f05a928b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_shape_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 82f71b61aec3cfe50fa9f8f18022023de62efdbe..771883b5f7e37fde25878ad1b27588d3e67fc678 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc index 907efa48d56508c0c8ad5fe7e7f69c0b9cea0726..82e7fa6264f3eee53dfa6c7bbedca0ea3d6bb4c8 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_softmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc index a36b3ca412fa7317b74424cecca0c6e0dd995603..dede41d2edb102c6b6a651637a8791ae2cae6e3b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h" #include #include "infer/space_to_depth.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc index a84cbd49812d7a9e0f1a834019b2830a066f1fcf..a2b8a1d04244377b4174b58f2cefc5424aa7c2f4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/splice.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc index 2b489c1165a9ad999d9945d3cd102620558ed601..272ae543fc34cc521a0cf13e4bf4ca68c1cb21be 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc index e4c737c3c9ec9b811a2d561290c4bd6152238c1d..b0c0d91e71263c3c0b8ecf83bd4b1b8db0e37231 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc index 691f94606a34f954ece2e8a7e4d7ae2ff9983343..127f47b53d16fafc4f5ca91d707cff3eedc65e3e 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/tile_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc index f924526d1ac6db66c6acfb294f8160d3c5de6231..2d10114557fb9e8f142286184473cdc66f9589f5 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_topk_parser.h" #include #include "infer/cxx_api/topk_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc index e6812ffea498198968b4de0b7b5aad2a49d557bf..c6af3e0ed8b92fd3f4ced484693e5d10a29ec017 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc index cb84a25e6b0825d9e7e8fd9d868a40411156ffc1..8bb81b40d885fdefa8aab07a66a147fdf32bfc22 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc @@ -19,7 +19,7 @@ #include #include "infer/tril.h" #include "infer/ops_func_impl/triu.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "op_def/auto_generate/gen_lite_ops.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc index e472f99874f3b1197b03a3996ab16d3370b3dc17..953ca6170556acec61b1e037a63340258134a5e5 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/unsqueeze.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc index 46072f0cda0097cdc8e9841d1d3850bf10c00b51..09152b3164d5c655442b78cc36bad66097901314 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/resize.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc index 51d4e54f32ed1db478bd9e1b29a611e7affc6647..6c248599664a2cb1ce57c30da9ce334570997346 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_where_parser.h" #include #include "infer/where.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/parser_utils.cc b/mindspore-lite/tools/converter/parser/parser_utils.cc index 6b6e9e86fccedf3bb14e2a7e7b9e25ab514049fb..9cb619263dc4e81ed57da06866a534ad554e2e2d 100644 --- a/mindspore-lite/tools/converter/parser/parser_utils.cc +++ b/mindspore-lite/tools/converter/parser/parser_utils.cc @@ -37,7 +37,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/format/to_format_base.h" #include "tools/optimizer/common/pass_manager_extends.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/common.h" #include "tools/converter/parser/conv2d_transpose_input_adjust.h" diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc index dc83e894382c02cf51dd05cd6abb44af2d211368..baf428e511ce91758eefc0c5e1009c16cb077b94 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc @@ -18,7 +18,7 @@ #include #include "include/securec.h" #include "infer/cxx_api/activation.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc index 70d46d1595e54cf3b46954bbd6abefa105a36e05..de6525e28513cfad71223b164f71e4f0b1711248 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_argmax_parser.h" #include #include "infer/cxx_api/arg_max_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc index 2c8e2fd6fad310d67c93bdd857de539279f8929d..1d8285847a18bb295b1ccf22ad9fde47dcd618a8 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc @@ -20,7 +20,7 @@ #include 
"infer/cxx_api/mul_fusion.h" #include "infer/cxx_api/div_fusion.h" #include "infer/cxx_api/sub_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc index eab856842044e51c1733b47c8149b56800b6e11e..e7fd0cd293ff6a50dfb74c89f725cfb4c64a338e 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_batchnorm_parser.h" #include #include "infer/fused_batch_norm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc index a83d920f7a9449d1ae59fad9fb9189091e0c87d4..f00cb23966ab32ac2fbb8d1626bb0f0c11fa4939 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { PrimitiveCPtr PytorchConvParser::Parse(const torch::jit::Node *torch_node, std::vector *input_indices) { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc index 57ae7bce9acd7c9c64891a5fc9df37f7e5c2d626..ef98afe693f71a385929bfeb2dd5c59cbb8f40ed 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_cumsum_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc index ac79fa0e6e92169b7fd08e1aa797aeee7dff60d6..d70c38c003455b33ff6c0e109483e0892b2d42e2 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_elementop_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc index 56daceeb45b586e7f61603b4cd43852b47f7d2e4..0e84d2f1f24cbb1a6d9b290eedd17d7f85ae65b2 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_embedding_parser.h" #include #include "infer/ops_func_impl/gather.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc index 1ee8b9d68794b7a59863b7abf8c69cc6bbde2294..7dcb239cfb71f898b7d49556f051e08ed4a4d454 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_flatten_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc index 2dd5300d174320dcb2b355aee86c49c570dd9437..b1d6728da7fc973f9b04b898ecbfe97d553b4ef3 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_gather_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc index 871a1b9efca6fc6397d3945fef66e00c40967828..7fce39b96000b42b46bc0bf6d514d1736ce872a0 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_list_construct_parser.h" #include #include "infer/make_tuple.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc index e6458245115e2b5d49ef25c1d7f7430536f8f20c..16d955de23b6f32b4d919a8ce210a3788e3d3ab4 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_logsoftmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc index 16e1e7bd67e5850b9eb589263bf5a1a7df7b85fd..54babcb457c63fe00c59c2e7e70ca365bf588020 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc @@ -20,7 +20,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "tools/common/tensor_util.h" #include "utils/check_convert_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/nn_ops.h" #include "mindspore/ops/op_def/sequence_ops.h" diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc index 
b61234b8e70603a297b4a26974e840ed8ee257c8..922b2d9f6b7f4c4c17e7653ff6583aae86f9d5e5 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_lstm_parser.h" #include #include "infer/lstm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc index eaf24192f93b8d31e7f59e933f3d934d4137f6f7..dec94040fe288ce9e044c6742a4e3efa3ea1cea9 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_matmul_parser.h" #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc index c53ccb31b68b38f2979c2b5b6db6d9b7a65e97c7..182bf2b99a8ddfe43a77af61dac63fc1672b76cd 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc @@ -30,7 +30,7 @@ #include "tools/converter/parser/pytorch/torch_graph_transfrom.h" #include "src/common/file_utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h b/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h index 399d0851752f64c2b6f0d50b8f7f1da9f24e40c5..aba1816dd62751443c3e0a382a366c14f1c74344 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h @@ -29,7 +29,7 @@ #include "src/common/log_adapter.h" #include "tools/common/tensor_util.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc index 04a65415c3d5bba8670bb60756ae780ae6625fe3..fe0d7a7bbd97dbf1f73803c93bffcb4b4f71583c 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.h" #include #include "infer/non_max_suppression.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc index 88f425c150013a4e38ff8eb0259550158f458614..ef2855c7e2cf8730a95afb03065e7762c8095584 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc @@ -18,7 +18,7 @@ 
#include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc index 9ed77aca2b3352444196870e89ceab414d146ac5..5333e3d833da0c6173186c9720ea0b1f182560a8 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc index 50e3fd56c60773787a2242800716fe7ea5e9f7e4..98a6704db7f9084f980bbed9a0eb39612e4a0a74 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_pow_parser.h" #include #include "infer/cxx_api/pow_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc index eac82dff1de3cfee314c1f7b413e1636958d72b2..912641da6de1d14f4eb990b9aa63eccc409a5176 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc @@ -18,7 +18,7 @@ #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/unsqueeze.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc index 47808075ac3e1d5a157fa81014c69ddbe1a021a0..e3bef32f3a8943ba745e909b2132e37c943e9b4b 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc index 6b3d63194b04b82743a5e15de93c821a57cf4ecb..9cf17ca9ceb36941e5e0d72cff8408d5bdfd24fb 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc @@ -18,7 +18,7 @@ #include "tools/converter/parser/pytorch/pytorch_node_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc index 8f73044d16f1fb8989d8626f9616b9a91e976662..a20ceef43b715f160f779e6d7646e2760b8d4139 
100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc @@ -21,7 +21,7 @@ #include "infer/cxx_api/exp_fusion.h" #include "infer/ops_func_impl/tan.h" #include "infer/eltwise.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc index b9db664340cf7da7f4c89348980ad5676a810cab..60e5601a24d20f21b2d40e5c421cc4fd05801ed9 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/unstack.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc b/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc index 088298b47cd5c3f3d21306c0144b66b7ca8d0185..35832e7bca1d1d0b1c15e1f73d3b09c95355bbbf 100644 --- a/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc +++ b/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "include/errorcode.h" #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "infer/return.h" #include "tools/lite_exporter/fetch_content.h" diff --git a/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc b/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc index 1f09ac492267bcde7548ab232a6bb0bfb1764189..cbadd72498d09187f5eeee8542a631226d161bab 100644 --- a/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc +++ b/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc @@ -20,7 +20,7 @@ #include "tools/converter/parser/tf/functionalize_while.h" #include "tools/converter/parser/tf/functionalize_cond.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore::opt { diff --git a/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc b/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc index 5c457b5b6af30b9f6e0cbb3435e5a00ca97d30f0..f14b76dc5b2e98fd007599f81cf35738942e8e7d 100644 --- a/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc +++ b/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc @@ -19,7 +19,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/parser/tf/tf_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc index 3a3e5d900ae5d8d27729b9401e2aa78e97fe4bba..51b1ceced8a6d177215e35f7f84ae61f4878082c 100644 --- a/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc +++ b/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc @@ -14,7 +14,7 @@ * limitations under 
diff --git a/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc
index 3a3e5d900ae5d8d27729b9401e2aa78e97fe4bba..51b1ceced8a6d177215e35f7f84ae61f4878082c 100644
--- a/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "tools/converter/parser/tf/tf_fake_quant_parser.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/converter/ops/ops_def.h"
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
diff --git a/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h b/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h
index e3779f529aa83ad5d4c20fa088e28e5a98848a13..ec17e5e8463aac9cd3a7ea0174ea1caf3c39beb8 100644
--- a/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h
+++ b/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h
@@ -25,7 +25,7 @@
 #include "proto/graph.pb.h"
 #include "ops/primitive_c.h"
 #include "utils/check_convert_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/converter/parser/parser_utils.h"
 #include "ops_utils/op_utils.h"
 #include "mindspore/ops/op_def/op_name.h"
diff --git a/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc
index 857861419cfebcbb6fcc095e5d7626fb5b5ec92e..40617f35f9fd6d1ae71285240f2536c5a87db8bc 100644
--- a/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "infer/sparse_to_dense.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
 
 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc b/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
index 02decf5995f95759b0fb931a4e3ab8abb9708ee7..e4cf2609d34b34fdf39c747128d4043f5d540974 100644
--- a/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
+++ b/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
@@ -31,7 +31,7 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "include/securec.h"
 #include "tools/converter/ops/ops_def.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc
index db1bf31205482dea7173940e67b51fff1385d2dd..e3e1e36c0589f184cd649796971e2e37cce2331b 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc
@@ -20,7 +20,7 @@
 #include "tools/converter/parser/tflite/tflite_util.h"
 #include "infer/cxx_api/prelu_fusion.h"
 #include "infer/cxx_api/activation.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc
index 4a35715dd802860dc8a1177dd5de22757e4dd737..798bdf74886e806a88ab4b3ccbcd23ae09d93d6b 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc
index 68cd48f001a6dd69d939ba21e653de79116134f4..7770dde6598b149ba37d8ca57d659fbd00efb967 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "infer/cxx_api/arg_max_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc
index 0044e9f0aa7fa4b85817fb61341ba88b75ffd2d6..8e164ce5a2b4651cefb01d8fc109c726903742bb 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "infer/cxx_api/arg_min_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc
index c1d694f46097766bc6232d13b16e075ce53f6573..abfad806631c815aa0e8c11b955cf8b0c0e3875f 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc
@@ -25,7 +25,7 @@
 #include "infer/cxx_api/exp_fusion.h"
 #include "infer/cxx_api/pow_fusion.h"
 #include "infer/squared_difference.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc
index b52150483fe9a794cda2f2cec1b026dccacff616..1eb54bc2d6bc64c909877ea3e12cd87eb4b363b0 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/mat_mul_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc
index f3326e4f02deb10ec760ee100044f04f2dfb9391..0782cfc0b2aa5dd676b884747365cb5aad756e63 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc
@@ -20,7 +20,7 @@
 #include
 #include
 #include "infer/batch_to_space.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc
index 98f6399006b649dc2933026d16734a26e6c0e84b..64f4c316add48277117ad72c683be6fa24d9d983 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc
index 5a5bf434d9932afc0866cbbce608407bc358fb68..9fa6ce66af2e7519a2aa1945d29ddd86e9a978c7 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc
index 4740f37c297135c7ddd2384fda2486240ede0636..f3d8369281e83b88599b29802cbf2f7a20a6b7ab 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc
index 91cc09a654d7d06bb2a168eac12a69aad10bd188..384b8cd22d7dae5ad3be9afa63e65bc4cca305eb 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/conv2d_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc
index b14924b63c69d6727efa16e98e1d7320aaddbe87..c5dcd8959c8e9c2ffdefdf4d13badb6ba13a433b 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/conv2d_transpose_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc
index f815fd5ec040df3c8cc483970f6e63b9bc65cb9b..ac388def97f33a12491dbdadb1d286b013ea9b37 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc
@@ -30,7 +30,7 @@
 #include "infer/fft_imag.h"
 #include "infer/mfcc.h"
 #include "infer/rfft.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc
index 882c8324d53fa84133a2e472eea922ceb72bc915..3c283bc064b2a95b77d8b32bbfdb3ead939cb4f4 100644
--- a/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "infer/depth_to_space.h"
-#include "nnacl/op_base.h"
+#include
"nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 20c31db0f924f419c39f0f538e9f92a5ec0b7f9f..4946d92f0c978130431ef90792c73009f2e4583c 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/quant_dtype_cast.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 780b889a6d4786daeec282ec23298eddacb52ad4..558d43851dc06693221e1cc7363974d9bfc199cb 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc index 66deea422379a16414798313f1d8fcfb6065b9af..884b4cd6be5e343d776818724283cfad819b6bf7 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/fill.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 5e9c32c2f933f550b80603b5e24a5757974565e2..dd86243b43439c5c5287dc23014e6671f6797e19 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/full_connection.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index 9b74e6e8d387826d47455be082661c4ebdd44f60..958d0447a86fa0329feba0b6b6c50f059d4b606c 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc index 1b920b44bcd7108b731818ed1fa32a7306eef163..70af88d6ee22ffb9b39f7dca7b606916e43565db 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -18,7 +18,7 @@ #include #include #include 
"mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc index 1ba4929159fd6e59633374de3ea50912ad6294ae..6026bcf0d113418af7d3f9baf50f7f06edf572eb 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/hashtable_lookup.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc index 65bb608e7bbd6e47c6f98114f5cc9571d71a91ba..9c6fe8e2d949712a13b32aacecef7da46af19e5f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc @@ -18,7 +18,7 @@ #include "tools/converter/parser/tflite/tflite_if_parser.h" #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc index 34052d26d8b45ccf5eb4958129bea2d5ab756a85..eb563769d8f0c4a50210da2b2162e3cac79e2614 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc @@ -28,7 +28,7 @@ #include "infer/space_to_batch_nd.h" #include "infer/space_to_depth.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc index 9a11f19468b466a9fba26c18f377dda9fa47861a..0881aca7da9f0bcb3008f8bb50675c58f39cfb9a 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/l2_normalize_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc index 2912213259e5f879779bc073aca6c46e0dc1de67..00015db064bf538c2a2af5b2281fa220f5311a75 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc index 2f04244c64d781e9aec385a063d92926f034c25f..978219ae43ad7d211ddf92a9191ab8b2a6b45c2a 100644 --- 
a/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index d971c9e7e3286de03011989571dc16d9df2bdf4d..a321c85d124d68a652b3023718a5af266e4cb806 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/lrn.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc index 1097bfcb68bb46f2d2db2879bbf11b83e37ec7f3..22082e171e07a58e84e691464f3023a123db9dbc 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/lsh_projection.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc index 50eff237e508d94e455d33a193bf15326a83c152..04a54b3742681362e92f8c41dea6f76b0e2d8824 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/tflite/tflite_matmul_parser.h" #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc index 5f9c57d660981100393b2b4dd5e4f34ec4aa6c3c..51f8729ae53856d5dde38f5472fca2eacb00c140 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -34,7 +34,7 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/converter/parser/unify_format.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index 537aaee80b2c2692ec411c99538555e462fa7134..aa0ca44800749dc72948d819ecc43d1b0fd78a35 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc 
b/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc index c19b52181c28a8606997d2b1ce9cbd7df72526c2..75dfd50a8ab53107ad59cd6c4a9bf4649fb2466e 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/pad_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index f2d0aa7a872b2f6a55a24c9d86a20f6a15ab6a1a..0e51c1dafcc0952228c9b791b191336b6241c05d 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -20,7 +20,7 @@ #include #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index fdd322b0fc959a4b126adf3fe50aa620acf92179..ec2ffeeae167c277614a8f5de940c06038d79c58 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -18,7 +18,7 @@ #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc index 40d52b46664f337ac880eb82a5749b6eb46683df..a7c9f4858000ec38874a6405ab42494d49661c7c 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc index 46e9399b54f5ac4a47e98c47c31f309025e617af..aeb7842b3052e85b91523e07ba22eddd1457e835 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index e12a1999be60ff9c5a24f1b352dca0e448b73e61..9a28ef6f55d1da96a0aae1d6ff612cf004fe0133 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/reduce_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc 
b/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index c464615fc5858153f6ffc0ed2290ae2438033445..c5c6e05fda0fafdf491cca3bcaaa3b384f7350cb 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc index 4ecf4a00e709eec0b27ac2fac2890c729fd17d08..cfb1217633899d2ddc3ea3814151eb7eef6ce201 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/resize.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index 12548cd4e7fc6f79d983e44a485e978283722e59..f90b380f0eb9e0011391c98d81f879c1109cb58a 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index 4d37da00c115c43b7104431513a6b1a4b91bfa0d..998e44d2b00801d04ac09f3911c7ee4b96e7da8c 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/reverse_sequence.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index 44145e65f4e61e8b64aa32994c85615bafc74f41..4bb69435968511f4c4624fddfc35401bf0a234cb 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc index 3779078670a07ee8a99ebb376488a74cb2ec4060..81fa131f18c9eb6ce01a952a23de96947ebfddff 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc index 3e38d87ef580b63251f3cedbaabce023023dff40..2aa648661567485a48a60decb812178170e770fd 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/skip_gram.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc index 0f35b0d02a53379a2de81d7946cd39313b0d49ba..fbf79aeb97693be405d890b309c9c11e7320d000 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/slice_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 07d28c7cdfa84369482a0c3e038bcd5dce40bc1c..2ddfaa87685b2c164ed1e5c50e5513f1f184bd55 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index e2278c0d3d56421042c1a4d0461bd30dc8ab93cd..1dff3ae4bcb0b1b93fa6b8d02cc9afd4c10a3864 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/space_to_batch_nd.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index 758bd3924eab93921839b70a23ac92158d76b936..77a698e50317f9183d42fecfd06cf47ce8b6227f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/space_to_depth.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index dd6be778bbf19a30280ddab7acabba157c99a5ab..f9eb6d76239ad99852c3b612b3d0c10e515c3c2a 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/sparse_to_dense.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc index ad54a6a9bc4379d7310839af3e4e69c3fad5efb0..a9ac3aa2d6648c88ceba1bcc5446744a4451c7b3 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index b9735021bb477f64ab55b51efbb4074bc49b3f38..eaa19d3d2770839cff9111fb7f28f540a4f92e49 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index 073f170a29e7fb3e0ed8b5bde9e40ed730a041ab..dcf8c051b1274ac2857ed0c64bca3fbea4f40315 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc index 26dece90fb9b23931f3fe987c351fd8e93cc25b6..d5404a51492681437315cb67df187870e49cb95f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/stack.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index bf967d7808ed7a292a85b905ee516fd65295a0b2..9c0f6a1fd09421b8fa565117296c4bd55d9d22fe 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc index 6ab7c327d50e6a3e895b304a3deb19e50e651767..8ad7c2e5d91437819a0e9470d99a966fc8b98a86 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/tile_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index f775424ea05eff5ffa9669e370cf8ed989497adb..6d03b339dd1837f3a3c5b1b9b78a3f6251ac5a52 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/topk_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index d0f1978015b6f644d9e531f2dabb007fba6f1495..ecfbf597a4ac5f6118ae0dc28e2b1a3528630913 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc index ee124cb45bd136f0d625290e372b6e031ac66ec6..b134cef73ae2f41ddfd8135b5be0de37e3dd1aad 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/unique.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index f75c390d4a1619c986628ba9cf9fb8d7b7df1ce6..1780068f5ed177ff2891ec4810afc3b671181dff 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/unstack.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc index 685277f23f1194d59701c3d858cefac1a4fdc04f..7b06d82b0d13f56422e980f5139d5140cdc6b311 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc @@ -21,7 +21,7 @@ #include #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc index 51322f799b38d19fd00082f9b27c45f70378276c..91bf6f46d9344444686ace82487a5e57cc55ae45 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/where.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc 
b/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc index d04d5b802c05d09a4fbf867f1c1d3f0f831a125d..9f0993873a0b8f00e659837d430aee8bfa54a073 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc @@ -19,7 +19,7 @@ #include #include #include "tools/converter/ops/while.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index a92cede16a5c926c1af82ed6e4c6926c7d9ffd7e..c79759e4ffd27e2b95b99f8c1e3ad5d852f14dc2 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/unify_format.cc b/mindspore-lite/tools/converter/parser/unify_format.cc index 2c814b4f86f3906c985c552d5da3be61512206c8..1dd15904a6a47960ed9ef3f9547d8b04714b8c48 100644 --- a/mindspore-lite/tools/converter/parser/unify_format.cc +++ b/mindspore-lite/tools/converter/parser/unify_format.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/converter/quantizer/cle_pattern.cc b/mindspore-lite/tools/converter/quantizer/cle_pattern.cc index fa8a0035c5e45618e2438fb2cdfffeccb34b92ff..949fc5b0542129de9fc8d34d286a487b1ec97423 100644 --- a/mindspore-lite/tools/converter/quantizer/cle_pattern.cc +++ b/mindspore-lite/tools/converter/quantizer/cle_pattern.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/quantizer/debug_info_manager.h b/mindspore-lite/tools/converter/quantizer/debug_info_manager.h index 498fa1694f1e78b03ff16b32bbae0edc01264ae1..936898c5fec8e8231beb7a38a90808f90fdb01d2 100644 --- a/mindspore-lite/tools/converter/quantizer/debug_info_manager.h +++ b/mindspore-lite/tools/converter/quantizer/debug_info_manager.h @@ -24,7 +24,7 @@ #include #include "tools/converter/quantizer/quantize_util.h" #include "tools/converter/graphdef_transform.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/statistic_utils.h" #include "src/litert/lite_model.h" #include "src/tensor.h" diff --git a/mindspore-lite/tools/converter/quantizer/fse_decoder.cc b/mindspore-lite/tools/converter/quantizer/fse_decoder.cc index 968dd3d0621b4967e44012bd1f697c90a8848a73..0595321fa14d919135929abf9ef3944dde5ca254 100644 --- a/mindspore-lite/tools/converter/quantizer/fse_decoder.cc +++ b/mindspore-lite/tools/converter/quantizer/fse_decoder.cc @@ -20,7 +20,7 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" 
#include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::quant { namespace { diff --git a/mindspore-lite/tools/converter/quantizer/fse_encoder.cc b/mindspore-lite/tools/converter/quantizer/fse_encoder.cc index 8da15c8718d62c666b85874c13a279776a471134..dda1d59f30e91df64e135dbc6453c7bb9c80c2f1 100644 --- a/mindspore-lite/tools/converter/quantizer/fse_encoder.cc +++ b/mindspore-lite/tools/converter/quantizer/fse_encoder.cc @@ -22,7 +22,7 @@ #include "ir/dtype/type_id.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" #include "tools/converter/quantizer/quantize_util.h" #include "tools/common/statistic_utils.h" diff --git a/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc b/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc index 7f30b48e489334b016d6e3fe4132ebbd4cf2ebe4..3f1aba2f5a0da87a480539937a049802091bd4fe 100644 --- a/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc +++ b/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc @@ -37,7 +37,7 @@ #include "tools/common/tensor_util.h" #include "src/common/utils.h" #include "tools/common/node_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "tools/converter/quantizer/bias_correction_strategy.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/converter/quantizer/gptq.h b/mindspore-lite/tools/converter/quantizer/gptq.h index c9ce4afe96fb00208cf1f6646bcfad46f86ba3d1..8a9553e4bee4d7840953ad2534992247715a2114 100644 --- a/mindspore-lite/tools/converter/quantizer/gptq.h +++ b/mindspore-lite/tools/converter/quantizer/gptq.h @@ -21,7 +21,7 @@ #include #include "tools/converter/quantizer/quantizer.h" #include "src/tensor.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "tools/converter/quantizer/gptq_quantizer.h" namespace mindspore::lite::quant { diff --git a/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h b/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h index 37ecbc952eb9ab6fbec8ac4b8dd79f1f5e3f9c20..cf1839001f2ee809c3f20a5082d3287aebcb8349 100644 --- a/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h +++ b/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h @@ -29,7 +29,7 @@ #include "tools/converter/cxx_api/converter_para.h" #include "ir/func_graph.h" #include "ir/anf.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore { namespace lite::quant { diff --git a/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc b/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc index c7d3b726f9381f4456b5c47dbad613c909aa6d55..50f6c7c4931be85933e5c02bcb6a791f4bd289fe 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc +++ b/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc @@ -25,7 +25,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "include/backend/optimizer/graph_optimizer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" diff --git 
a/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc b/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc index 187c2c0cb5b835756826b577801aebc297f0a8ef..9d81956842959ed67e6a7da46040720fd3b6a63a 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc +++ b/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "include/backend/optimizer/graph_optimizer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/converter/quantizer/quant_param_holder.h b/mindspore-lite/tools/converter/quantizer/quant_param_holder.h index 35ff9d077363b7e2a9014833be7208a8d7d45dd7..0bf8e2e9af57b86f13815ffb3c3033b7aaedb8dd 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_param_holder.h +++ b/mindspore-lite/tools/converter/quantizer/quant_param_holder.h @@ -22,7 +22,7 @@ #include #include "ir/anf.h" #include "schema/inner/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "tools/converter/quantizer/quant_params.h" diff --git a/mindspore-lite/tools/converter/quantizer/quant_strategy.cc b/mindspore-lite/tools/converter/quantizer/quant_strategy.cc index eb802a642985b00daf0a1c7e375898799c1420ca..17f5bb39536953e5b1225d57fb5c0970713c0c57 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_strategy.cc +++ b/mindspore-lite/tools/converter/quantizer/quant_strategy.cc @@ -26,7 +26,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/quantizer/quantize_util.cc b/mindspore-lite/tools/converter/quantizer/quantize_util.cc index f9860299d074a31afed4520976461589cc15adbb..8e8040a6f07593d085321dcc2d0caf89ca1ae1bd 100644 --- a/mindspore-lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore-lite/tools/converter/quantizer/quantize_util.cc @@ -46,7 +46,7 @@ #include "tools/converter/parser/parser_utils.h" #include "mindspore/ops/op_def/other_ops.h" #include "utils/anf_utils.h" -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/converter/quantizer/smooth_quant.cc b/mindspore-lite/tools/converter/quantizer/smooth_quant.cc index 9d3c8d5eef074d8da08def837a71585a6137def2..2346c111bae995196f99b878bb9d3a872dcbf01b 100644 --- a/mindspore-lite/tools/converter/quantizer/smooth_quant.cc +++ b/mindspore-lite/tools/converter/quantizer/smooth_quant.cc @@ -24,9 +24,9 @@ #include "tools/converter/quantizer/insert_quant_node_manager.h" #include "tools/optimizer/common/gllo_utils.h" #include "thread/threadpool.h" -#include "nnacl/fp32/scale_fp32.h" +#include "nnacl_c/fp32/scale_fp32.h" #include "infer/cxx_api/scale_fusion.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" diff --git a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc index af42e0c5814a966675e4c4a4a3b5349432d22d61..08e3b8c29c629da56a8a53a4affb6c0ab5477ac0 100644 --- a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc +++ b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc @@ -18,7 +18,7 @@ #include #include #include "tools/optimizer/common/gllo_utils.h" -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/quantizer/quant_params.h" #include "tools/converter/quantizer/quantize_util.h" #include "tools/lite_exporter/fetch_content.h" diff --git a/mindspore-lite/tools/converter/registry/CMakeLists.txt b/mindspore-lite/tools/converter/registry/CMakeLists.txt index 31e4c75f0e52b158fd6d4019a98f8a7d75503be5..53595758050ce98123ff189fbf909606699cf55f 100644 --- a/mindspore-lite/tools/converter/registry/CMakeLists.txt +++ b/mindspore-lite/tools/converter/registry/CMakeLists.txt @@ -16,7 +16,7 @@ set(REG_SRC ${CONVERT_REG_SRC} ${KERNEL_REG_DIR}/../extendrt/delegate/plugin/tensorrt_executor_plugin.cc ${KERNEL_REG_DIR}/../extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc ${CONVERTER_DIR}/converter_context.cc - ${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c + ${NNACL_DIR}/tensor_c_utils.c ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc ) set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS diff --git a/mindspore-lite/tools/converter/registry/model_parser_registry.cc b/mindspore-lite/tools/converter/registry/model_parser_registry.cc index 043e00814368ef4a84b9e3b0daf496181e29fc9b..1421775acdfedb212098f539c6a898be237b35b4 100644 --- a/mindspore-lite/tools/converter/registry/model_parser_registry.cc +++ b/mindspore-lite/tools/converter/registry/model_parser_registry.cc @@ -17,7 +17,7 @@ #include "include/registry/model_parser_registry.h" #include #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace registry { diff --git a/mindspore-lite/tools/converter/registry/pass_registry.cc b/mindspore-lite/tools/converter/registry/pass_registry.cc index 4e81081ef6a9aa699b54e22da4435aa0619ce96e..7bbaccb5d9ed8a16576bbbf3049d3f509a24b1ea 100644 --- a/mindspore-lite/tools/converter/registry/pass_registry.cc +++ b/mindspore-lite/tools/converter/registry/pass_registry.cc @@ -20,7 +20,7 @@ #include #include #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace registry { diff --git a/mindspore-lite/tools/cropper/build_cropper_config.sh b/mindspore-lite/tools/cropper/build_cropper_config.sh index e4195a5c021d8623b6e5fea408dd2a6bff4c9192..b787f84a908feec814da65d27e6edbb75e010234 100644 --- a/mindspore-lite/tools/cropper/build_cropper_config.sh +++ b/mindspore-lite/tools/cropper/build_cropper_config.sh @@ -252,23 +252,23 @@ getCommonFile() { while IFS='' read -r line; do mindrt_files_h+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/core/mindrt/include/thread/*.h) others_files_h=( "${MINDSPORE_LITE_HOME}"/src/litert/infer_manager.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/infer_register.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/nnacl_utils.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h + 
"${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h "${MINDSPORE_LITE_HOME}"/src/common/ops/populate/populate_register.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/op_base.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/op_base.h "${MINDSPORE_HOME}"/mindspore/core/include/ir/dtype/type_id.h "${MINDSPORE_HOME}"/mindspore/core/include/utils/overload.h "${MINDSPORE_LITE_HOME}"/tools/common/option.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/infer.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/common_infer.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensor_c.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/errorcode.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/common_func.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensorlist_c.h - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensorlist_c_utils.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/infer.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensor_c.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/errorcode.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/common_func.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h "${MINDSPORE_HOME}"/mindspore/core/include/utils/log_adapter.h "${MINDSPORE_HOME}"/mindspore/core/include/ir/api_tensor_impl.h "${MINDSPORE_LITE_HOME}"/src/litert/cxx_api/tensor/tensor_impl.h @@ -312,29 +312,29 @@ getCommonFile() { ) # sava all assembly files assembly_files=() - while IFS='' read -r line; do assembly_files+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/assembly/*/*.S) + while IFS='' read -r line; do assembly_files+=("$line"); done < <(ls ${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/assembly/*/*.S) others_files_c=( - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/nnacl_utils.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/errorcode.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/errorcode.c "${MINDSPORE_LITE_HOME}"/src/litert/infer_manager.cc "${MINDSPORE_LITE_HOME}"/src/common/ops/populate/populate_register.cc "${MINDSPORE_LITE_HOME}"/src/common/ops/populate/custom_populate.cc - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/infer_register.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/common_infer.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/fp32/shape_fusion_fp32.cc 
"${MINDSPORE_HOME}"/mindspore/core/utils/status.cc "${MINDSPORE_HOME}"/mindspore/core/utils/log_adapter.cc - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/kernel.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensorlist_c_utils.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/base/format_transpose.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/base/cast_base.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp32/transpose_fp32.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp32/pack_fp32.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp16/pack_fp16.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.c - "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/nnacl_common.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/kernel.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/base/cast_base.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c + "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/nnacl_common.c ) all_files=("${src_files[@]}" "${regist_files[@]}" "${common_files[@]}" "${runtime_files_cc[@]}" "${others_files_c[@]}" "${assembly_files[@]}" "${nnacl_files_cc[@]}" "${mindrt_files[@]}" @@ -428,7 +428,7 @@ getCommonFile getTrainCommonFile # get src/common/ops getOpsFile "REG_POPULATE\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/common/ops/populate" "prototype" & -getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/infer" "prototype" & +getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/infer" "prototype" & # support for cpu getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat32, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeFloat32" & getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat16, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeFloat16" & @@ -436,11 +436,11 @@ getOpsFile "REG_KERNEL\(.*?, kNumberTypeInt8, PrimitiveType_" "${MINDSPORE_LITE_ getOpsFile "REG_KERNEL\(.*?, kNumberTypeInt32, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeInt32" & getOpsFile "REG_KERNEL\(.*?, kNumberTypeBool, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeInt32" & #support for nnacl kernel -getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeFloat32" "kNumberTypeFloat32" & -getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeFloat16" "kNumberTypeFloat16" & -getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeInt8" "kNumberTypeInt8" & -getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeInt32" "kNumberTypeInt32" & -getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" 
"kNumberTypeInt32" "kNumberTypeBool" & +getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeFloat32" "kNumberTypeFloat32" & +getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeFloat16" "kNumberTypeFloat16" & +getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeInt8" "kNumberTypeInt8" & +getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeInt32" "kNumberTypeInt32" & +getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeInt32" "kNumberTypeBool" & getNnaclKernelFile "NNACL_KERNEL\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl" "kNumberTypeFloat32" "kNumberTypeFloat32" & getNnaclKernelFile "NNACL_KERNEL\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl" "kNumberTypeFloat16" "kNumberTypeFloat16" & getNnaclKernelFile "NNACL_KERNEL\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl" "kNumberTypeInt8" "kNumberTypeInt8" & diff --git a/mindspore-lite/tools/graph_kernel/common/infer_shape.cc b/mindspore-lite/tools/graph_kernel/common/infer_shape.cc index 304cce1627fdc34ae1f746420b2530cf449fa50e..24750e33f2b4ddb212b7c8b1f0b6a2374d86ab76 100644 --- a/mindspore-lite/tools/graph_kernel/common/infer_shape.cc +++ b/mindspore-lite/tools/graph_kernel/common/infer_shape.cc @@ -22,9 +22,9 @@ #include "schema/model_generated.h" #include "src/tensor.h" #include "src/common/utils.h" -#include "nnacl/infer/common_infer.h" -#include "nnacl/infer/infer_register.h" -#include "nnacl/custom_parameter.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/custom_parameter.h" namespace mindspore::graphkernel { using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/tools/graph_kernel/common/utils.h b/mindspore-lite/tools/graph_kernel/common/utils.h index ce7afd0d92ff6318c80f80462211338040dc7a6b..78a6f184f06a087f8e7fc14b4e9361219aaffca0 100644 --- a/mindspore-lite/tools/graph_kernel/common/utils.h +++ b/mindspore-lite/tools/graph_kernel/common/utils.h @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/tensor_c.h" +#include "nnacl_c/tensor_c.h" #include "common/kernel_build_info.h" #include "include/backend/kernel_info.h" diff --git a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h b/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h index f1afd16e3faf0cd14576e33ad1ee9e1b26835465..8a7485542668a40cf98702b0e81150df5fa0d273 100644 --- a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h +++ b/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/custom_parameter.h" +#include "nnacl_c/custom_parameter.h" namespace mindspore::kernel { using AkgParallelLambda = int (*)(int task_id, int num_task, void *cdata); diff --git a/mindspore-lite/tools/lite_exporter/anf_exporter.cc b/mindspore-lite/tools/lite_exporter/anf_exporter.cc index 93d4fbe5a41bd12025b0f5e2a97406c56ef05e44..60f1bf4e726aafd3409883658f8a9ce163d4396e 100644 --- a/mindspore-lite/tools/lite_exporter/anf_exporter.cc +++ b/mindspore-lite/tools/lite_exporter/anf_exporter.cc @@ -31,7 +31,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/ops_utils/op_utils.h" #include 
"mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/depend.h" #include "infer/cxx_api/partial_fusion.h" #include "infer/make_tuple.h" diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.cc b/mindspore-lite/tools/lite_exporter/fetch_content.cc index 97d6a97390bda75ff7d10d3a31eb16b9744f0bd0..b7187d558d682ff5dbdda73501e9b37b985fba1b 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.cc +++ b/mindspore-lite/tools/lite_exporter/fetch_content.cc @@ -25,7 +25,7 @@ #include "mindapi/base/format.h" #include "mindspore/ops/op_def/framework_ops.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/ops/anf_utils.h" #include "src/common/ops/populate/populate_register.h" diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.h b/mindspore-lite/tools/lite_exporter/fetch_content.h index 5b5dadf1b91a6ccfd08459bbb27039108f90cf15..3b370c68b582dc3c63f6525410ef791a10184729 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.h +++ b/mindspore-lite/tools/lite_exporter/fetch_content.h @@ -24,7 +24,7 @@ #include "ir/primitive.h" #include "ir/func_graph.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/common/format_utils.cc b/mindspore-lite/tools/optimizer/common/format_utils.cc index ba90ce1853b98ed7a2e1bc9332e870d400b7d096..ce11b23ee74cefb317d2212f608d35b8584aee23 100644 --- a/mindspore-lite/tools/optimizer/common/format_utils.cc +++ b/mindspore-lite/tools/optimizer/common/format_utils.cc @@ -64,7 +64,7 @@ #include "infer/deformable_conv2d.h" #include "infer/roi_align.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/graph_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.cc b/mindspore-lite/tools/optimizer/common/gllo_utils.cc index 98759ae17894485dffd65917fda551e486f1b56b..81718a474589882a8fe1f53abde0d53cb96f2188 100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.cc @@ -36,7 +36,7 @@ #include "frontend/operator/ops.h" #include "include/backend/optimizer/helper.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/helper.h" diff --git a/mindspore-lite/tools/optimizer/common/helper.cc b/mindspore-lite/tools/optimizer/common/helper.cc index b519cc213a8f0537659e8644db43c5dc550e48be..1fde0eb3ae6704d36c501ce9db899d5fa89cb0c5 100644 --- a/mindspore-lite/tools/optimizer/common/helper.cc +++ b/mindspore-lite/tools/optimizer/common/helper.cc @@ -20,7 +20,7 @@ #include #include "tools/optimizer/common/helper.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc b/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc index 
dcf1a58f101ffa12f08b49a463bff8a5139e6e3e..e809de0ed77b4f293f45fdedfde364a8c531cb3b 100644 --- a/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc +++ b/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc @@ -16,7 +16,7 @@ #include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/helper.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::opt { AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { diff --git a/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc b/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc index 3ea6c6b33f55771c8940f5e6246ad56bb94dfbf9..39cd9f708467a39a2a81478aa0c9d8163b67ef0a 100644 --- a/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc +++ b/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc @@ -18,7 +18,7 @@ #include "tools/optimizer/const_fold/fold_along_infershape.h" #include #include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc b/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc index 63ff7bab72df0c789cd450ec0a61ae337482acda..1cb1982af0d6ce2e6a2b09a19f7973e7b7bc4f88 100644 --- a/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc +++ b/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc @@ -20,7 +20,7 @@ #include #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc b/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc index 135ca1383c987ce667dfcca5588454db181c68f8..a935cecd5d739f9e25dd9f576680161284154251 100644 --- a/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc +++ b/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc @@ -29,7 +29,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/parallel/spliter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fisson/fisson_util.cc b/mindspore-lite/tools/optimizer/fisson/fisson_util.cc index b5f6decc35e2b94a34d629ba1be4db0b3739862a..c751ffd7754bd4cef0ee29fa807e0a7db8c2c071 100644 --- a/mindspore-lite/tools/optimizer/fisson/fisson_util.cc +++ b/mindspore-lite/tools/optimizer/fisson/fisson_util.cc @@ -26,7 +26,7 @@ #include "infer/make_tuple.h" #include "tools/optimizer/parallel/spliter.h" #include "tools/optimizer/parallel/split_strategy.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" #include "include/registry/converter_context.h" diff --git a/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc b/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc index 
857fb0938be0cdee4d5a8acb51b77af9c05b0fd5..4dc8858fab510c17fd8c71f9b906f653b895704b 100644 --- a/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc +++ b/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc @@ -17,7 +17,7 @@ #define USE_DEPRECATED_API #include "tools/optimizer/fisson/iter_node_outputs.h" #include "tools/optimizer/parallel/spliter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc b/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc index 332bdf4dbd5c6cf6930a28245e236bd5427fa21f..dd2b6036215e01d745a99f4368420d4c695ac1d7 100644 --- a/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc +++ b/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/parallel/split_strategy.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; diff --git a/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc b/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc index 839d53fb73ca5b158ffa0f823d227290711136f1..baf877a79e08b5ddfb1d2bcd5b66f05a77326c8f 100644 --- a/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc +++ b/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc @@ -19,7 +19,7 @@ #include #include #include "tools/optimizer/parallel/spliter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc b/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc index de5b5aa445dbd4f49f11eae79c811b07059f353b..ff001d9ace422f469781db372be9fa386a038372 100644 --- a/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc +++ b/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc @@ -21,7 +21,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/common/node_util.h" #include "tools/converter/quantizer/quant_params.h" diff --git a/mindspore-lite/tools/optimizer/format/to_format_base.cc b/mindspore-lite/tools/optimizer/format/to_format_base.cc index 87257b46b57a6dade542c3412da79117678da6c2..f161f18ce18bc72a7df8f4a0aa9d98689ff7a561 100644 --- a/mindspore-lite/tools/optimizer/format/to_format_base.cc +++ b/mindspore-lite/tools/optimizer/format/to_format_base.cc @@ -26,7 +26,7 @@ #include "src/common/utils.h" #include "tools/common/tensor_util.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc index 0720af1f4feb7937b3dc77856f05564ddbb02315..cb372a043a96de89bd12eeff63efe29bb008ba1f 100644 --- a/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc @@ -21,7 +21,7 @@ #include #include 
"mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" #include "ops_utils/op_utils.h" #include "src/common/utils.h" @@ -48,11 +48,13 @@ STATUS DoFusion(CNodePtr cur_cnode, const CNodePtr &pre_cnode) { MS_CHECK_TRUE_MSG(cur_act_prim->GetAttr(ops::kMaxVal) != nullptr, RET_ERROR, "Get max value failed."); MS_CHECK_TRUE_MSG(pre_act_prim->GetAttr(ops::kMinVal) != nullptr, RET_ERROR, "Get min value failed."); MS_CHECK_TRUE_MSG(cur_act_prim->GetAttr(ops::kMinVal) != nullptr, RET_ERROR, "Get min value failed."); - auto pre_max_val = - pre_act_type == RELU ? FLT_MAX : pre_act_type == RELU6 ? kValueThreshold6 : pre_act_prim->get_max_val(); + auto pre_max_val = pre_act_type == RELU ? FLT_MAX + : pre_act_type == RELU6 ? kValueThreshold6 + : pre_act_prim->get_max_val(); auto pre_min_val = (pre_act_type == RELU || pre_act_type == RELU6) ? 0 : pre_act_prim->get_min_val(); - auto cur_max_val = - cur_act_type == RELU ? FLT_MAX : cur_act_type == RELU6 ? kValueThreshold6 : cur_act_prim->get_max_val(); + auto cur_max_val = cur_act_type == RELU ? FLT_MAX + : cur_act_type == RELU6 ? kValueThreshold6 + : cur_act_prim->get_max_val(); auto cur_min_val = (cur_act_type == RELU || cur_act_type == RELU6) ? 0 : cur_act_prim->get_min_val(); auto new_max_val = std::min(pre_max_val, cur_max_val); auto new_min_val = std::max(pre_min_val, cur_min_val); diff --git a/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc index 484f881982aeada7304fae8944b538019fc2d1f7..da34ced9840da9e3f37ecef5dc760ef2dfbdcc53 100644 --- a/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/add_fusion.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc index 27b90a503afdb2a6799c1bef1fce241810b92865..1bd9fb8facf85431f62f9ec193547034dd8223c8 100644 --- a/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/add_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc b/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc index cb3aa516bfc48cab81d4a1918763615c33586778..6b50a3cf57c5eb3a5295c1f7418c7c41c23d115c 100644 --- a/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc @@ -30,7 +30,7 @@ #include "include/common/utils/anfalgo.h" #include "include/backend/anf_runtime_algorithm.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" 
#include "tools/optimizer/graph/node_infershape.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc b/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc index db29238cfe85d818fc101a5b3b0f6e76e4cebada..fc50cdd7b382d1edb70329fafb4839f345222cac 100644 --- a/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc @@ -28,7 +28,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/range_v2.h" #include "mindspore/ops/op_def/image_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc index f4c45cca41a2e16773036b6b6886eec8aeba4a75..d627d93fae9b3e583e5b6c675a9cb9a3606c3274 100644 --- a/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "infer/cxx_api/activation.h" #include "infer/affine.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc b/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc index fe0156ea03f9d52acd69149416bef33ad6e1b396..bb043793d3b862e5faeb6cd91d8755bbf54f3131 100644 --- a/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc @@ -25,7 +25,7 @@ #include "infer/splice.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc index d2176d93d33fe3ba86d4ace432410f0ec6bdc53e..b62990a49835796fa391aa2f034c5dc0974fe214 100644 --- a/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/infer/all_reduce.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h" #include "ir/anf.h" diff --git a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc index bfad59fe17aa260f1a0a235097bf5f761abef715..ea8485c7566f145e69863e24394d004d6590c0af 100644 --- a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -27,7 +27,7 @@ #include "tools/converter/quantizer/quantize_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include 
"nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc index 85d208317a59796ab848a749a46b308e842b29da..dfc8fc248c7a32b9971ec47de1a3c1d04886aa17 100644 --- a/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/tensor_util.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc b/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc index 513d06a093fd01b481c32572200f67a09b537c7e..c9bec643360ee563334646a97a9dcdc0c0f1f9eb 100644 --- a/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc @@ -25,7 +25,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "tools/converter/quantizer/quant_param_holder.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc index f519a7ee88d9802484eacdb47b2339ad686ded0c..730d45631bbebd09da524a3a7619133e422ff28a 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index a23f4a7aaae5abda4c2a10d89cf184d7e10b6eee..da78af0a8f60f5cb9ed33f8f0a8ee63664ea126e 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -23,7 +23,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindapi/base/types.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc index 53b1f07c7095408a58b130c777a1abacf32bc61e..f9cc8268aa9745c37e690c1cd0e3800a6c99466d 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -20,7 +20,7 @@ #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h" -#include "nnacl/op_base.h" +#include 
"nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc index 30d9022f389ac4dd202a156b79eacbdc25ccd071..2ea850740a8b8aef286cea250887b60cdadf5ff1 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc @@ -22,7 +22,7 @@ #include "tools/common/tensor_util.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc index 44d868ec42e878709a435f96cbe70df5b3d75c3b..0851cab2b29a4ad5ebfa519ee7a698b0b3a8eb62 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc @@ -25,7 +25,7 @@ #include "infer/cxx_api/pad_fusion.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops/primitive_c.h" #include "ops_utils/op_utils.h" #include "src/common/utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc index a3e03803ffa84645782b98c5ec4c0f2361b90885..ea98d48a847e24d0bf5755c7153acae206626256 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc @@ -20,7 +20,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" namespace mindspore::opt { diff --git a/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc index c1363d766a1e03d347ae9a87f8b10991f4d1cb31..cc57a7fca9f794cda2efea9fbb491f77fa308d35 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -26,7 +26,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc index 284aa0434155435218e7821b4481f3dd0c46297f..c94104888bf287dde0721a2928cda18d8dec106c 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc index df594472726bbbae70d5cb1ebcf2452a3cb558d3..3e9ffd3d5e4b1110d5aeef143a7c28c7fff51e7c 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/sequence_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" namespace mindspore::opt { diff --git a/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc index a95d560a2b720630ec1a42b51e78e4a4b9f0c069..b8ab59753247f80d5b06570a2409a6de03e8e594 100644 --- a/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc @@ -26,7 +26,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/array_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "tools/common/tensor_util.h" #include "ops_utils/op_utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc index 84f8d794496317b15e3870fe68415aadcb3d0781..9e55beaf9651ce2b08c8a38e9ed118c9d92cc52b 100644 --- a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc @@ -30,7 +30,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "tools/common/tensor_util.h" #include "ops_utils/op_utils.h" @@ -992,7 +992,8 @@ STATUS EncoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const Eq } } act_type_ = (is_position_bias_) ? (ActType::ActType_Relu) - : (is_fast_gelu_) ? (ActType::ActType_FastGelu) : (ActType::ActType_Gelu); + : (is_fast_gelu_) ? 
(ActType::ActType_FastGelu) + : (ActType::ActType_Gelu); if (!is_position_bias_ && !is_use_past_ && !is_query_layer_) { if (!IsActGELU(func_graph, equiv, is_act_)) { return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc b/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc index 1c21920d16a499e0333f140fd94fe4f4c105a7dc..455487af9cea5ebf55e1c0b4b3c27dd72d0eaa6e 100644 --- a/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc @@ -20,7 +20,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc b/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc index 47ac51771814f3488e671a643f8018773a13478e..94bcd429162384648ff64f2c2857e938a4045997 100644 --- a/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc @@ -22,7 +22,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/infer/custom.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/string_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git a/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc b/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc index e210aaa18a189df7a73728e22baa516b311ad417..91cdb1cca8e7d4719e80e39d39249d3ea992b24f 100644 --- a/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/infer/custom.h" #include "infer/f_f_n.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" diff --git a/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc b/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc index c89c64b3a321d81d24cbc18c89a1288c93d54764..fbcdb17bd4143dab701a94667edb349a5a655345 100644 --- a/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc +++ b/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc @@ -18,7 +18,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/cxx_api/flash_attention.h" diff --git a/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc b/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc index 56af4daf8e0f0b0b9577ec1d2cae1d0923b0f64a..0502d6df862a4964364abad43f49ae519bb9af3b 100644 --- a/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc @@ -19,7 +19,7 @@ #include "op_def/auto_generate/gen_lite_ops.h" #include 
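The `CheckPattern` hunk above is another behavior-preserving reflow of a nested conditional: position-bias encoders select ReLU, fast-GELU patterns select FastGelu, and everything else defaults to Gelu. An equivalent, more explicit sketch of that selection (the enum and its values are stand-ins invented here, not the real `ActType` definition):

```cpp
#include <cstdio>

// Stand-in for the ActType enum referenced in the hunk; values are illustrative.
enum class ActType { Relu, FastGelu, Gelu };

// Same selection the reflowed ternary chain performs, written as if/else.
ActType SelectActType(bool is_position_bias, bool is_fast_gelu) {
  if (is_position_bias) return ActType::Relu;
  if (is_fast_gelu) return ActType::FastGelu;
  return ActType::Gelu;
}

int main() {
  // false/true selects the fast-GELU branch.
  std::printf("%d\n", static_cast<int>(SelectActType(false, true)));  // prints 1
  return 0;
}
```

The diff itself only moves line breaks so that each `:` arm of the ternary chain sits on its own line; the selected activation is unchanged.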
"mindspore/ops/op_def/nn_ops.h" #include "infer/custom.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc b/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc index 848808ab541202497a8de6aaa56613a8edb96c79..614be3ad4852265b321c20af21fd7ad1185fc402 100644 --- a/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/add_fusion.h" #include "infer/cxx_api/full_connection.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc b/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc index 58b5306f54fe5fa3e45a1225b339ba5951d0c812..e31920fcbb03f96c687447038a8df169782040dd 100644 --- a/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/full_connection.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc index 61fd1096d91ab1330f583c19943df8d972d75111..92b0f7f78fd2dce7a07d687ebebfc8ef91492a1d 100644 --- a/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc @@ -21,7 +21,7 @@ #include "infer/cxx_api/activation.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc index f6f18a7f7a50fcab66ba72b144b637b40fe23999..2e4dcddbc5d8d56395f303a2db69e3ad851b0d1a 100644 --- a/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "include/common/utils/utils.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc b/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc index f1c70f862d033b8453073f89f6b025fc6e1e7701..295d33cb3291152ace458d19b5eba99aeff5f444 100644 --- a/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc @@ -26,7 +26,7 @@ #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include 
"src/common/ops/ops_utils.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc b/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc index 92f27e075d5c6b2baf9b1cc90ec4b093835f8dc0..6a7e7088d29ff2b9cf6538c13df9bf4b698023af 100644 --- a/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc @@ -17,7 +17,7 @@ #include "tools/optimizer/fusion/hard_swish_fusion.h" #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc index a5d90a54859a4a3cc1217af64e5e6d973f93e159..a85a38d8b779ad423a0f87b3631b975dd94f4e70 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/nn_optimizer_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc index 9ce4f27b121c91700d005ffa4018d4e841a3cab1..7fca7f7b1f6e391e8c46b40f0061bdba0117f583 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_k.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc index 31f36d957c7bc2c33e0226ec5efe4c835ce0d704..512291922128c1c5b5258b816511726f7be87506 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/nn_optimizer_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_k.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc index 4734e0659e318b05d32f41903b3f4f5f484ebbe0..0b03e4bc2ff915b410c57594d7b5f613ce6873da 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc @@ -23,7 +23,7 @@ 
#include "src/common/log_adapter.h" #include "infer/splice.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/math_ops.h" diff --git a/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc index 02869e99b6ab93b895e51c1068ea105a473aa9c8..fedb5bf1531af6277766b39739045175b8e76a6c 100644 --- a/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/leaky_relu.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc index 187b9a1249c887472b559488ae65b19a60a764b9..918b634e1ee14d1089e20732f4338908fb671e77 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/mat_mul_fusion.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc index 117989677acd36fe5b9b8128c8e31339984ba1fe..42110b949614af6dd8b7346d39ab43da4696bb77 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc @@ -24,7 +24,7 @@ #include "infer/cxx_api/add_fusion.h" #include "infer/cxx_api/mat_mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc index 761c7dc34f43d8d08945b65d464df583a21d23de..ad98cfbe36cba020b3680e349b60102e9cb0bffd 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/infer/all_reduce.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h" #include "ir/anf.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc index 29944d5e79f1fc80a5b1577b62705e7e12ef7358..78cc55c3740d9d2c0abb379c862defb4c1a300dc 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc @@ -22,7 +22,7 @@ 
#include "infer/cxx_api/mat_mul_fusion.h" #include "infer/cxx_api/mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc index fa9090a2adbda535ef11268b70607f27e8e66a9d..1f5ea6034f23f2ea46c32739e4381b1b83f26f24 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc index a32fc34da4f487451906712ce33bb0fc6eba8eea..b10a258758dc0103a9b250daa5cd515306b06366 100644 --- a/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc b/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc index 75e1211c571c20571fb3d9a1a1fd82ee0fd88068..68d334b487c07300deac5088bb092c5984aca575 100644 --- a/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc @@ -20,7 +20,7 @@ #include #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/add_fusion.h" #include "infer/cxx_api/mul_fusion.h" #include "infer/cxx_api/scale_fusion.h" diff --git a/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc index 80f3771e300c5737cee4080aada7a83b8d283697..bc1025b2884791940aa9b49abe5503557e17d78b 100644 --- a/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc @@ -29,7 +29,7 @@ #include "infer/cxx_api/mul_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc index 71f13cd2d97592f02d52c95b6a9db7a6d5cba23e..7fd376eaa8ff1d9b57aaa28dbede880b02e38d6a 100644 --- a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc +++ 
b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc @@ -30,7 +30,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "tools/common/tensor_util.h" #include "ops_utils/op_utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc b/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc index 94b50f3b2ed96f00365ae35ef85ae1806cee4373..0b2e98ebac76205dde262ab534dd768b2ed6fda0 100644 --- a/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc @@ -27,7 +27,7 @@ #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/ops/anf_utils.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc index 5e61d4bf025e1fd34c6c855a4923000ce439cc83..9a78a2f2d48347ee40f86330cc66e349e5965098 100644 --- a/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc @@ -18,7 +18,7 @@ #include "tools/optimizer/fusion/onnx_gelu_fusion.h" #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ccsrc/include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git a/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc index 55ac223dd47f4619868b44cf3e816c8f1eb4782e..594023da49d642e43b7dfdd4c1f9c14537f0c773 100644 --- a/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc @@ -17,7 +17,7 @@ #include "tools/optimizer/fusion/prelu_fusion.h" #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/prelu_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc b/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc index 38bffca64bbd3156e1df7d6f8fb8cfcb2bca4a0b..f8ce59dfe651b00cddb752654d1fcc1d9a36822a 100644 --- a/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc @@ -16,7 +16,7 @@ #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc index 756b1aaed3573ad9590935ecd0fa241f0dd7501f..f2d73ef33cfa23157ec7f407b50a32bec228f4de 100644 --- a/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc @@ -22,7 +22,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include 
"tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc b/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc index d17bf8542a4a460bbb7a9ff6d418061348363c4c..bc97c6e4293ba7135fb4abdb20e1d74bec53db46 100644 --- a/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc +++ b/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc @@ -23,7 +23,7 @@ #include "tools/optimizer/fusion/strided_slice_checker.h" #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc b/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc index 6045c9fb94c21108d9c7ca461e44111c10d0e79d..b12a5634508a5d12bb8ffb1eec6f9f29e870982f 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc index 370a5b86c876c3efd480b85169c8a7b619b98df5..4d909670b071c0a61f0d980ecf4fcbc8a99c2420 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "tools/lite_exporter/fetch_content.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc index c525934b2eba783a6f656293ee724410ac3c3f7a..7cf375e69061b1637375d91a365acda513fd1fe7 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc @@ -25,7 +25,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc index 5b2af29324e8b5a57fe9a5932142bad8dfb2f7d3..ff949dc15a43712902b8f4757d4851d0dfee596b 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc @@ -19,7 +19,7 @@ #include 
"mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc index a2307937b2184fbfa54795926fcf496704685b36..87e4c45dbc5cf5aeae6ca343c40ae315fbe43007 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/lite_exporter/fetch_content.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc b/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc index e32773694b5fb258a5f6cbb75d6fce8112f6b06c..1755f1e7cd7a18fd7c439ff01c364f1c75e5a344 100644 --- a/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/tensor_util.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/resize.h" #include "mindapi/base/types.h" diff --git a/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc index 91d5c74f40a9d0925d38596d04efb6309087ea9d..543cdfc5440bf7de981f1a358f80318faa5dc8f2 100644 --- a/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/scale_fusion.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc b/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc index b0aaaf236c36294a8b3ab539bd1ac5338462f2f9..214424bf9968018e13b829b07baa1004a96081f9 100644 --- a/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc @@ -21,7 +21,7 @@ #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc index d68ace82e319497dc6abc820466f7384752608c2..ff537a4c8490b0a0be0853edc8b115a10eb825ce 100644 --- a/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc @@ -24,7 +24,7 @@ #include "tools/common/tensor_util.h" #include "infer/cxx_api/scale_fusion.h" 
#include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc index 9e058ea8ef89abcefce8e1fdc3a77ef498a60e08..34c15542b2399d2934870cd3c648722834de6323 100644 --- a/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc @@ -22,7 +22,7 @@ #include "ops_utils/op_utils.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc b/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc index fa947d7a6466776275b058ae8cdc44d844c31251..74e9632beb4bc4da68e011354ece4445db22a7e8 100644 --- a/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc @@ -22,7 +22,7 @@ #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc b/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc index fd3a7728d37f87b894418aff3baa86e7bfabc594..1eb956e706e00133d4c0729c58aa3abd591531f6 100644 --- a/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/unsqueeze.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc b/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc index e575974cbb2cd00dcc44ec796c72b0cea06aa414..d4df93720740f73dd10e8a2d0340e1b7bfa89c6d 100644 --- a/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc @@ -23,7 +23,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "ir/func_graph.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc index a9c7d8662416bfd8625d9df01893e2abd8e58462..fabb55b274d4eea79553c7452755aabf589399e4 100644 --- a/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc @@ -24,7 +24,7 @@ #include "ops_utils/op_utils.h" 
#include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 81c66cac65c75c6e8eaa8b3a25e33e32f9b05653..5589a730cec9c3a301448353c4b293745bead9ba 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -32,7 +32,7 @@ #include "src/common/utils.h" #include "tools/common/tensor_util.h" #include "include/common/utils/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc index c78153f7062b90284f99285a4afe0d3a8644f637..91ed75861a8002e7403f342aef08f2010837b9c8 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindapi/base/types.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc index 1143a737498b0e72caa9d54ac2e896589086b671..34ef3c0bde8e93d47bf6c11adbf58177d1d1cfb8 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc @@ -27,7 +27,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "tools/optimizer/common/helper.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc index 588558e319cc5c7b6fa892cf988bb46e629fd9e2..c987f4450bf67d6bf2a3f9fb78678225dbcfec3b 100644 --- a/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc @@ -32,7 +32,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/helper.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc index 
59e2ff676325209486ad19321b78833b1813b7e6..fd9f90989174f4e819472663441e49d33f59da72 100644 --- a/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc @@ -25,7 +25,7 @@ #include "tools/converter/quantizer/quant_param_holder.h" #include "tools/converter/quantizer/quantize_util.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc index 3a9ad0a574356922488a6d3a058f9c723e05a3d4..cd77724fb8361a859234d8a228fd88067117bea2 100644 --- a/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc b/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc index 850d4227989b8816bd54fef0397039a15da29160..aa2827644b9036c2febe4f900129ef7bc00165d0 100644 --- a/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc @@ -26,7 +26,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/format_utils.h" #include "infer/cxx_api/scale_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc b/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc index fab33d732a06c4d31bf8632431575ec40012399c..f68648c57e710a418d22ca699362bb94da87b8a3 100644 --- a/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc @@ -22,7 +22,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/format_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc index 8064276d380099b3de82b4357f570c28d98e20e9..60e61d5a4567191f4bf3bfa46a567459da85eb2c 100644 --- a/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include 
"mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc b/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc index 70ada19c6cf6fa2fd30d2b2da6b2848344ceef4a..e6a40b0761aa51b3d652dd9e7b91a5320d788dff 100644 --- a/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc +++ b/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc @@ -26,7 +26,7 @@ #include "infer/tensor_array.h" #include "infer/tensor_array_read.h" #include "infer/tensor_array_write.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/make_tuple.h" #include "infer/return.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc b/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc index e57043670af95eb47ad3f6d4791f24f0fe63d472..fd6679f42223818c84e925f40354b2ef3d314950 100644 --- a/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc @@ -17,7 +17,7 @@ #include "tools/optimizer/graph/attr_to_args_pass.h" #include #include "tools/common/node_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "ops/primitive_c.h" #include "ops/base_operator.h" diff --git a/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc b/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc index ce3382f32f2306eec5ef22e4875f0a9cd33257d4..f818d0ae7b23f5a9c2823c705360e04e8ebedeb0 100644 --- a/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc +++ b/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc @@ -20,7 +20,7 @@ #include #include "mindspore/ops/op_def/array_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/common/utils/anfalgo.h" #include "mindspore/ccsrc/include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" @@ -45,8 +45,9 @@ ShapeVector CalcBroadcastShape(AnfNodePtr cond, AnfNodePtr x, AnfNodePtr y) { auto cond_size = cond_shape.size(); auto x_size = x_shape.size(); auto y_size = y_shape.size(); - ShapeVector broadcast_shape = - cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape : x_size > y_size ? x_shape : y_shape; + ShapeVector broadcast_shape = cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape + : x_size > y_size ? x_shape + : y_shape; auto n = broadcast_shape.size(); for (size_t i = n; i > 0; --i) { auto cond_i = cond_size < i ? 
1 : cond_shape[cond_size - i]; diff --git a/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc b/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc index e7e25ad154a40e5addad4af96101d51416a4951d..970d9df18a055931ef7435996bc6996a5441ec86 100644 --- a/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/tensor.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc index 13138233f66ea96128f136f1822a70c20dab6431..af4ed935b3922e3777667696ec89aa187cc1896f 100644 --- a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc @@ -28,7 +28,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_adapter.h" #include "tools/common/node_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc b/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc index 0aa5756ac14c48d610f7e261d29cddd322cc1334..ffa5b0c255ecfa1deb52dd0286d1930956b5673d 100644 --- a/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc @@ -21,7 +21,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "abstract/ops/primitive_infer_map.h" #include "tools/lite_exporter/fetch_content.h" diff --git a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc index 5d984bf764becaa5c38e6b1bd635fa0940fd0869..7e2cf736a50cebb83ef80d82a75b730279822926 100644 --- a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -28,7 +28,7 @@ #include "src/common/common.h" #include "src/common/utils.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/optimizer/graph/specify_graph_input_format.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc b/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc index d345d546eea4ad82550206752353e1ea968a0a89..e90cb51ef6287122033581f8044b6acd099c8629 100644 --- a/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc @@ -26,7 +26,7 @@ #include "src/common/log_adapter.h" #include "tools/common/tensor_util.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/infershape_pass.cc b/mindspore-lite/tools/optimizer/graph/infershape_pass.cc index 81ed965a0665380c983fd7eca0e74660b29ee37a..6ae3dcd6bf009e70836886ea5b123424bee2fbf7 100644 --- a/mindspore-lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/infershape_pass.cc @@ -21,7 +21,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/graph/decrease_transpose_algo.h" diff --git a/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc b/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc index 47b53c66bdb861f5868832b5543df6bb977851a3..c8b686abbf88ae64b78913ad01f96080bfc9aaae 100644 --- a/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc @@ -25,7 +25,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "src/tensor.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc b/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc index 1f8008b41444e55b9d2cbc193ee31c511c830eaf..dfd783819fa51b995a0951109427b1d3737500dc 100644 --- a/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc @@ -27,7 +27,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/tensor.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" diff --git a/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc b/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc index 7d2539af9bc54bc0724a82aea208258ae9f12c2d..87af8994f6aa12668ac61ac17ecbdc03c77edf26 100644 --- a/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc +++ b/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc @@ -24,7 +24,7 @@ #include "src/tensorlist.h" #include "tools/optimizer/common/format_utils.h" #include "utils/ms_utils_secure.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc b/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc index 892be53bd8024a0095bceadd9e52b44aa90c62a7..3705d36258380a132163c1795ec31c3b2b8fb087 100644 --- a/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/fill.h" #include "src/common/utils.h" diff --git a/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc 
b/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc index 2971ae293dcf24be75a6a3994f2e7472660ebde1..e912b52cbda84196de5c77a76423382e006885a9 100644 --- a/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc @@ -17,7 +17,7 @@ #define USE_DEPRECATED_API #include "tools/optimizer/graph/mul_constant_pass.h" #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/cxx_api/mul_fusion.h" #include "src/common/utils.h" diff --git a/mindspore-lite/tools/optimizer/graph/node_infershape.cc b/mindspore-lite/tools/optimizer/graph/node_infershape.cc index 69d3edf9e778617a911d7f3d318ecd3b379328a0..ab03d4c38e1254c124ac869043ab99d6f6a6e01a 100644 --- a/mindspore-lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore-lite/tools/optimizer/graph/node_infershape.cc @@ -35,7 +35,7 @@ #include "src/tensorlist.h" #include "src/registry/kernel_interface_registry.h" #include "tools/optimizer/graph/lite_tensor_extractor.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/format/to_nchw_format.h" #include "tools/optimizer/format/to_nhwc_format.h" diff --git a/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc b/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc index e55755fd230c8317e66f077891d3129fcc71df50..485b1eafcfc1d36fd892f3ea8eee57f7a1836f5e 100644 --- a/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc +++ b/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc @@ -30,7 +30,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" diff --git a/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 74fe203a553b524419c73f6153746d08873886b8..7ece481d08c9ca6d4e821d73b8430bee06c4f4f2 100644 --- a/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -31,7 +31,7 @@ #include "infer/depend.h" #include "infer/cxx_api/pad_fusion.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc index ccc49d4c9463bf8167ff4997cb03538f77b5f063..e0beb62d2f798deaf49a48df8ad7077cc6513fe4 100644 --- a/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc @@ -32,7 +32,7 @@ #include "tools/optimizer/common/helper.h" #include "include/backend/optimizer/helper.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc b/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc index 33dce353a09f308aec8af89ad8a5d028ff6641d6..87dd2fa161f721caa862db92959df7e06830ddb9 100644 --- a/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc +++ b/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "include/errorcode.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl//op_base.h" +#include "nnacl_c//op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc b/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc index 22cbae3ae9b302679c99100ac9aca648f0767fd4..5b73b8bf7b3644183dc69c4c41fc5415e8175237 100644 --- a/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc +++ b/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc @@ -25,7 +25,7 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/format_utils.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc b/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc index 72ab75dfb6edee2052af4dc3788600a728dcea3c..18debeb7e7eabcaa5b7410e79252f7ce347f81e0 100644 --- a/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc +++ b/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc @@ -26,7 +26,7 @@ #include "tools/optimizer/common/format_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/make_tuple.h" #include "mindspore/ccsrc/include/common/utils/utils.h" diff --git a/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc b/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc index 2ab1c639b61106f5800be540cfaff33030602d3c..afccb911aaf81b9c28a5d48118dffbdc13374840 100644 --- a/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc +++ b/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc @@ -33,7 +33,7 @@ #include "infer/cxx_api/slice_fusion.h" #include "ops_utils/op_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" diff --git a/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc b/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc index ab455c66e774936d2785ace6fe181c638e0fcbc0..ac4a402b2d5996b4b9af03c2034a7a7627e09d55 100644 --- a/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" 
#include "tools/optimizer/common/gllo_utils.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc b/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc index 1ffbd0d7b6a934f1a45663b5893b0282ccc1b425..06daaae828eb3fa6996c49ed22afd21a616df72e 100644 --- a/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc @@ -30,7 +30,7 @@ #include "tools/optimizer/parallel/operator_info_register.h" #include "tools/optimizer/parallel/spliter.h" #include "tools/optimizer/fisson/fisson_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "utils/anf_utils.h" diff --git a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc index d997fcd7288e3779adea6306d358c0d6b256a41e..46f7ca33250826da3480b3abf899f07e134d9832 100644 --- a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc @@ -24,7 +24,7 @@ #include "include/securec.h" #include "mindspore/ops/op_def/conv_pool_ops.h" #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/common/utils/utils.h" diff --git a/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc b/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc index 073aa215ad61249e5c0835f90e2591d4ab0b78c2..d143584417761b2e1a99af857d31153eb80d31d6 100644 --- a/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/parallel/spliter.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/parallel/split_strategy.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; diff --git a/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc b/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc index b9e86cdcd4c47294dee191ae3492ba8a90fd5f00..3cb6fae13500d6efdc5bd7e5a94b8f0edceb3ea0 100644 --- a/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc +++ b/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc @@ -17,7 +17,7 @@ #define USE_DEPRECATED_API #include "tools/optimizer/parallel/multi_node_split.h" #include "tools/optimizer/parallel/multi_conv_info.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/parallel/operator_info.cc b/mindspore-lite/tools/optimizer/parallel/operator_info.cc index 30631f3941a8b23d7c71529cbb7a21f5bb03d1eb..8075b7ce05f2030ef0bb5cfae2afc78ded1f129c 100644 --- a/mindspore-lite/tools/optimizer/parallel/operator_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/operator_info.cc @@ -22,7 +22,7 @@ #include "infer/tuple_get_item.h" #include "include/common/utils/utils.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include 
"src/common/log_util.h" diff --git a/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc b/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc index 99f301541ada0d2f06e45613fb2db73a478e955d..4dd3e50d0497cb369b3bb3b992bc01625aacba00 100644 --- a/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc +++ b/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc @@ -20,7 +20,7 @@ #include "ir/tensor.h" #include "tools/optimizer/parallel/operator_info_register.h" #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/parallel/split_strategy.cc b/mindspore-lite/tools/optimizer/parallel/split_strategy.cc index 82c45d01080cde89c7587380bedc5038beee1884..e76b036afb7037363c09cf4eff687aae42f1ebbd 100644 --- a/mindspore-lite/tools/optimizer/parallel/split_strategy.cc +++ b/mindspore-lite/tools/optimizer/parallel/split_strategy.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore {